diff --git a/dev/subtree_config.xml b/dev/subtree_config.xml index 3d230120ed7..d2e97bcbd9e 100644 --- a/dev/subtree_config.xml +++ b/dev/subtree_config.xml @@ -33,7 +33,7 @@ name="capnproto" internal_path="libs/EXTERNAL/capnproto" external_url="https://github.com/capnproto/capnproto.git" - default_external_ref="v0.9.1"/> + default_external_ref="v1.0.2"/> ./dockcross + chmod +x ./dockcross + - name: super-test + run: | + ./dockcross ./super-test.sh quick g++ MacOS: runs-on: macos-latest strategy: @@ -59,13 +101,13 @@ jobs: strategy: fail-fast: false matrix: - os: ['windows-2016', 'windows-latest'] + os: ['windows-2019', 'windows-latest'] include: - - os: windows-2016 - target: 'Visual Studio 15 2017' + - os: windows-2019 + target: 'Visual Studio 16 2019' arch: -A x64 - os: windows-latest - target: 'Visual Studio 16 2019' + target: 'Visual Studio 17 2022' arch: -A x64 steps: - uses: actions/checkout@v2 @@ -104,7 +146,7 @@ jobs: rmdir /s /q C:\PROGRA~1\POSTGR~1 echo "Building Cap'n Proto with MinGW" - cmake -Hc++ -Bbuild-output -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=debug -DCMAKE_INSTALL_PREFIX=%CD%\capnproto-c++-install -DCMAKE_SH="CMAKE_SH-NOTFOUND" + cmake -Hc++ -Bbuild-output -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=debug -DCMAKE_INSTALL_PREFIX=%CD%\capnproto-c++-install -DCMAKE_SH="CMAKE_SH-NOTFOUND" -DCMAKE_CXX_STANDARD_LIBRARIES="-static-libgcc -static-libstdc++" cmake --build build-output --target install -- -j2 echo "Building Cap'n Proto samples with MinGW" @@ -113,27 +155,49 @@ jobs: cd build-output\src ctest -V -C debug - Cygwin: - runs-on: windows-latest + # Cygwin: + # runs-on: windows-latest + # strategy: + # fail-fast: false + # steps: + # - run: git config --global core.autocrlf false + # - uses: actions/checkout@v2 + # # TODO(someday): If we could cache the Cygwin installation we wouldn't have to spend three + # # minutes installing it for every build. Unfortuntaley, actions/cache@v1 does not preserve + # # DOS file attributes, which corrupts the Cygwin install. In particular, Cygwin marks + # # symlinks with the "DOS SYSTEM" attribute. We could cache just the downloaded packages, + # # but it turns out that only saves a couple seconds; most of the time is spend unpacking. + # - name: Install Cygwin + # run: | + # choco config get cacheLocation + # choco install --no-progress cygwin + # - name: Install Cygwin additional packages + # shell: cmd + # run: | + # C:\tools\cygwin\cygwinsetup.exe -qgnNdO -R C:/tools/cygwin -l C:/tools/cygwin/packages -s http://mirrors.kernel.org/sourceware/cygwin/ -P autoconf,automake,libtool,gcc,gcc-g++,binutils,libssl-devel,make,zlib-devel,pkg-config,cmake,xxd + # - name: Build and test + # shell: cmd + # run: | + # C:\tools\cygwin\bin\bash -lc 'export PATH=/usr/local/bin:/usr/bin:/bin; cd /cygdrive/d/a/capnproto/capnproto; ./super-test.sh quick' + Linux-bazel-clang: + runs-on: ubuntu-20.04 strategy: fail-fast: false + matrix: + clang_version: [16] steps: - - run: git config --global core.autocrlf false - - uses: actions/checkout@v2 - # TODO(someday): If we could cache the Cygwin installation we wouldn't have to spend three - # minutes installing it for every build. Unfortuntaley, actions/cache@v1 does not preserve - # DOS file attributes, which corrupts the Cygwin install. In particular, Cygwin marks - # symlinks with the "DOS SYSTEM" attribute. We could cache just the downloaded packages, - # but it turns out that only saves a couple seconds; most of the time is spend unpacking. - - name: Install Cygwin - run: | - choco config get cacheLocation - choco install --no-progress cygwin - - name: Install Cygwin additional packages - shell: cmd + - uses: actions/checkout@v3 + - uses: bazelbuild/setup-bazelisk@v2 + - name: install dependencies run: | - C:\tools\cygwin\cygwinsetup.exe -qgnNdO -R C:/tools/cygwin -l C:/tools/cygwin/packages -s http://mirrors.kernel.org/sourceware/cygwin/ -P autoconf,automake,libtool,gcc,gcc-g++,binutils,libssl-devel,make,zlib-devel,pkg-config,cmake,xxd - - name: Build and test - shell: cmd + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git + # todo: replace with apt-get when clang-16 is part of ubuntu lts + - name: install clang + uses: egor-tensin/setup-clang@v1 + with: + version: ${{ matrix.clang_version }} + - name: super-test run: | - C:\tools\cygwin\bin\bash -lc 'export PATH=/usr/local/bin:/usr/bin:/bin; cd /cygdrive/d/a/capnproto/capnproto; ./super-test.sh quick' + cd c++ + bazel test --verbose_failures --test_output=errors //... diff --git a/libs/EXTERNAL/capnproto/.github/workflows/release-test.yml b/libs/EXTERNAL/capnproto/.github/workflows/release-test.yml index cdcedfaf73e..bd30a934222 100644 --- a/libs/EXTERNAL/capnproto/.github/workflows/release-test.yml +++ b/libs/EXTERNAL/capnproto/.github/workflows/release-test.yml @@ -9,12 +9,12 @@ on: jobs: Linux: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: fail-fast: false matrix: # We can only run extended tests with the default version of g++, because it has to match - # the verison of g++-multilib for 32-bit cross-compilation, and alternate versions of + # the version of g++-multilib for 32-bit cross-compilation, and alternate versions of # g++-multilib generally aren't available. Clang is more lenient, but we might as well be # consistent. The quick tests should be able to catch issues with older and newer compiler # versions. @@ -42,11 +42,21 @@ jobs: run: | ./super-test.sh MinGW-Wine: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: fail-fast: false steps: - uses: actions/checkout@v2 + # See: https://github.com/actions/virtual-environments/issues/4589#issuecomment-1100899313 + # GitHub's Ubuntu image installs all kinds of stuff from non-Ubuntu repositories which cause + # conflicts with Ubuntu packages ultimately preventing installation of wine32. Let's try to + # fix that... + - name: remove unwanted packages and repositories + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libgd3/focal libpcre2-8-0/focal libpcre2-16-0/focal libpcre2-32-0/focal libpcre2-posix2/focal + sudo apt-get purge -yqq libmono* moby* mono* php* libgdiplus libpcre2-posix3 libzip4 - name: install dependencies run: | export DEBIAN_FRONTEND=noninteractive @@ -61,7 +71,7 @@ jobs: run: | ./super-test.sh mingw i686-w64-mingw32 cmake-packaging: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: fail-fast: false steps: @@ -83,7 +93,7 @@ jobs: run: | ./super-test.sh cmake-package cmake-static Android: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: fail-fast: false steps: diff --git a/libs/EXTERNAL/capnproto/.gitignore b/libs/EXTERNAL/capnproto/.gitignore index 812d709b4df..5714e8c2f7a 100644 --- a/libs/EXTERNAL/capnproto/.gitignore +++ b/libs/EXTERNAL/capnproto/.gitignore @@ -21,6 +21,7 @@ # Ekam build artifacts. /c++/tmp/ /c++/bin/ +/c++/deps/ # setup-ekam.sh /c++/.ekam @@ -71,6 +72,13 @@ /c++/m4/ltversion.m4 /c++/m4/lt~obsolete.m4 /c++/samples/addressbook +/c++/.cache/ # editor artefacts *~ + +# cross-compiling / glibc testing +/dockcross + +# bazel output +bazel-* diff --git a/libs/EXTERNAL/capnproto/CMakeLists.txt b/libs/EXTERNAL/capnproto/CMakeLists.txt index 17f12819f94..e1c7145abf7 100644 --- a/libs/EXTERNAL/capnproto/CMakeLists.txt +++ b/libs/EXTERNAL/capnproto/CMakeLists.txt @@ -1,3 +1,5 @@ -cmake_minimum_required(VERSION 3.4...3.13) # ! This line is edited to get rid of a CMake deprecation error +cmake_minimum_required(VERSION 3.16) project("Cap'n Proto Root" CXX) +include(CTest) + add_subdirectory(c++) diff --git a/libs/EXTERNAL/capnproto/c++/.bazelignore b/libs/EXTERNAL/capnproto/c++/.bazelignore new file mode 100644 index 00000000000..b67bdf49d4c --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/.bazelignore @@ -0,0 +1 @@ +ekam-provider \ No newline at end of file diff --git a/libs/EXTERNAL/capnproto/c++/.bazelrc b/libs/EXTERNAL/capnproto/c++/.bazelrc new file mode 100644 index 00000000000..cab5c8c13cb --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/.bazelrc @@ -0,0 +1,26 @@ +common --enable_platform_specific_config + +build:unix --cxxopt='-std=c++14' --host_cxxopt='-std=c++14' --force_pic --verbose_failures +build:unix --cxxopt='-Wall' --host_cxxopt='-Wall' +build:unix --cxxopt='-Wextra' --host_cxxopt='-Wextra' +build:unix --cxxopt='-Wno-strict-aliasing' --host_cxxopt='-Wno-strict-aliasing' +build:unix --cxxopt='-Wno-sign-compare' --host_cxxopt='-Wno-sign-compare' +build:unix --cxxopt='-Wno-unused-parameter' --host_cxxopt='-Wno-unused-parameter' + +build:linux --config=unix +build:macos --config=unix + +# See https://bazel.build/configure/windows#symlink +startup --windows_enable_symlinks +# We use LLVM's MSVC-compatible compiler driver to compile our code on Windows +# under Bazel. MSVC is natively supported when using CMake builds. +build:windows --compiler=clang-cl + +build:windows --cxxopt='/std:c++14' --host_cxxopt='/std:c++14' --verbose_failures +build:windows --cxxopt='/wo4503' --host_cxxopt='/wo4503' +# The `/std:c++14` argument is unused during boringssl compilation and we don't +# want a warning when compiling each file. +build:windows --cxxopt='-Wno-unused-command-line-argument' --host_cxxopt='-Wno-unused-command-line-argument' + +# build with ssl, zlib and bazel by default +build --//src/kj:openssl=True --//src/kj:zlib=True --//src/kj:brotli=True diff --git a/libs/EXTERNAL/capnproto/c++/.bazelversion b/libs/EXTERNAL/capnproto/c++/.bazelversion new file mode 100644 index 00000000000..5e3254243a3 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/.bazelversion @@ -0,0 +1 @@ +6.1.2 diff --git a/libs/EXTERNAL/capnproto/c++/BUILD.bazel b/libs/EXTERNAL/capnproto/c++/BUILD.bazel new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/EXTERNAL/capnproto/c++/CMakeLists.txt b/libs/EXTERNAL/capnproto/c++/CMakeLists.txt index 2acc7811582..1cc53b90ce8 100644 --- a/libs/EXTERNAL/capnproto/c++/CMakeLists.txt +++ b/libs/EXTERNAL/capnproto/c++/CMakeLists.txt @@ -1,9 +1,10 @@ -cmake_minimum_required(VERSION 3.4...3.13) # ! This line is edited to get rid of a CMake deprecation error +cmake_minimum_required(VERSION 3.16) project("Cap'n Proto" CXX) -set(VERSION 0.9.1) +set(VERSION 1.0.2) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +include(CTest) include(CheckIncludeFileCXX) include(GNUInstallDirs) if(MSVC) @@ -25,7 +26,6 @@ set(INSTALL_TARGETS_DEFAULT_ARGS # Options ====================================================================== -option(BUILD_TESTING "Build unit tests and enable CTest 'check' target." ON) option(EXTERNAL_CAPNP "Use the system capnp binary, or the one specified in $CAPNP, instead of using the compiled one." OFF) option(CAPNP_LITE "Compile Cap'n Proto in 'lite mode', in which all reflection APIs (schema.h, dynamic.h, etc.) are not included. Produces a smaller library at the cost of features. All programs built against the library must be compiled with -DCAPNP_LITE. Requires EXTERNAL_CAPNP." OFF) @@ -46,6 +46,10 @@ set(WITH_OPENSSL "AUTO" CACHE STRING # define list of values GUI will offer for the variable set_property(CACHE WITH_OPENSSL PROPERTY STRINGS AUTO ON OFF) +set(WITH_ZLIB "AUTO" CACHE STRING + "Whether or not to build libkj-gzip by linking against zlib") +set_property(CACHE WITH_ZLIB PROPERTY STRINGS AUTO ON OFF) + # shadow cache variable original value with ON/OFF, # so from now on OpenSSL-specific code just has to check: # if (WITH_OPENSSL) @@ -64,6 +68,68 @@ elseif (WITH_OPENSSL) find_package(OpenSSL REQUIRED COMPONENTS Crypto SSL) endif() +# shadow cache variable original value with ON/OFF, +# so from now on ZLIB-specific code just has to check: +# if (WITH_ZLIB) +# ... +# endif() +if(CAPNP_LITE) + set(WITH_ZLIB OFF) +elseif (WITH_ZLIB STREQUAL "AUTO") + find_package(ZLIB) + if(ZLIB_FOUND) + set(WITH_ZLIB ON) + else() + set(WITH_ZLIB OFF) + endif() +elseif (WITH_ZLIB) + find_package(ZLIB REQUIRED) +endif() + +set(WITH_FIBERS "AUTO" CACHE STRING + "Whether or not to build libkj-async with fibers") +# define list of values GUI will offer for the variable +set_property(CACHE WITH_FIBERS PROPERTY STRINGS AUTO ON OFF) + +# CapnProtoConfig.cmake.in needs this variable. +set(_WITH_LIBUCONTEXT OFF) + +if (WITH_FIBERS OR WITH_FIBERS STREQUAL "AUTO") + set(_capnp_fibers_found OFF) + if (WIN32 OR CYGWIN) + set(_capnp_fibers_found ON) + else() + # Fibers need makecontext, setcontext, getcontext, swapcontext that may be in libc, + # or in libucontext (e.g. for musl). + # We assume that makecontext implies that the others are present. + include(CheckLibraryExists) + check_library_exists(c makecontext "" HAVE_UCONTEXT_LIBC) + if (HAVE_UCONTEXT_LIBC) + set(_capnp_fibers_found ON) + else() + # Try with libucontext + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(libucontext IMPORTED_TARGET libucontext) + if (libucontext_FOUND) + set(_WITH_LIBUCONTEXT ON) + set(_capnp_fibers_found ON) + endif() + else() + set(_capnp_fibers_found OFF) + endif() + endif() + endif() + + if (_capnp_fibers_found) + set(WITH_FIBERS ON) + elseif(WITH_FIBERS STREQUAL "AUTO") + set(WITH_FIBERS OFF) + else() + message(FATAL_ERROR "Missing 'makecontext', 'getcontext', 'setcontext' or 'swapcontext' symbol in libc and no libucontext found: KJ won't be able to build with fibers. Disable fibers (-DWITH_FIBERS=OFF).") + endif() +endif() + if(MSVC) # TODO(cleanup): Enable higher warning level in MSVC, but make sure to test # build with that warning level and clean out false positives. diff --git a/libs/EXTERNAL/capnproto/c++/Makefile.am b/libs/EXTERNAL/capnproto/c++/Makefile.am index 1e3fd8e948f..1567491d4d5 100644 --- a/libs/EXTERNAL/capnproto/c++/Makefile.am +++ b/libs/EXTERNAL/capnproto/c++/Makefile.am @@ -27,6 +27,7 @@ EXTRA_DIST = \ src/capnp/testdata/segmented-packed \ src/capnp/testdata/errors.capnp.nobuild \ src/capnp/testdata/errors2.capnp.nobuild \ + src/capnp/testdata/no-file-id.capnp.nobuild \ src/capnp/testdata/short.txt \ src/capnp/testdata/flat \ src/capnp/testdata/binary \ @@ -174,9 +175,11 @@ includekj_HEADERS = \ src/kj/async-unix.h \ src/kj/async-win32.h \ src/kj/async-io.h \ + src/kj/cidr.h \ src/kj/async-queue.h \ src/kj/main.h \ src/kj/test.h \ + src/kj/win32-api-version.h \ src/kj/windows-sanity.h includekjparse_HEADERS = \ @@ -263,6 +266,7 @@ endif libkj_la_LIBADD = $(PTHREAD_LIBS) libkj_la_LDFLAGS = -release $(SO_VERSION) -no-undefined libkj_la_SOURCES= \ + src/kj/cidr.c++ \ src/kj/common.c++ \ src/kj/units.c++ \ src/kj/memory.c++ \ @@ -305,11 +309,19 @@ libkj_async_la_SOURCES= \ src/kj/async-io-win32.c++ \ src/kj/timer.c++ +if BUILD_KJ_GZIP +libkj_http_la_LIBADD = libkj-async.la libkj.la -lz $(ASYNC_LIBS) $(PTHREAD_LIBS) +libkj_http_la_LDFLAGS = -release $(SO_VERSION) -no-undefined +libkj_http_la_SOURCES= \ + src/kj/compat/url.c++ \ + src/kj/compat/http.c++ +else libkj_http_la_LIBADD = libkj-async.la libkj.la $(ASYNC_LIBS) $(PTHREAD_LIBS) libkj_http_la_LDFLAGS = -release $(SO_VERSION) -no-undefined libkj_http_la_SOURCES= \ src/kj/compat/url.c++ \ src/kj/compat/http.c++ +endif libkj_tls_la_LIBADD = libkj-async.la libkj.la -lssl -lcrypto $(ASYNC_LIBS) $(PTHREAD_LIBS) libkj_tls_la_LDFLAGS = -release $(SO_VERSION) -no-undefined @@ -504,6 +516,7 @@ endif heavy_tests = \ src/kj/async-test.c++ \ src/kj/async-xthread-test.c++ \ + src/kj/async-coroutine-test.c++ \ src/kj/async-unix-test.c++ \ src/kj/async-unix-xthread-test.c++ \ src/kj/async-win32-test.c++ \ diff --git a/libs/EXTERNAL/capnproto/c++/WORKSPACE b/libs/EXTERNAL/capnproto/c++/WORKSPACE new file mode 100644 index 00000000000..d94a279ef31 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/WORKSPACE @@ -0,0 +1,54 @@ +workspace(name = "capnp-cpp") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//:build/load_br.bzl", "load_brotli") + +http_archive( + name = "bazel_skylib", + sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", + ], +) + +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") + +bazel_skylib_workspace() + +http_archive( + name = "ssl", + sha256 = "873ec711658f65192e9c58554ce058d1cfa4e57e13ab5366ee16f76d1c757efc", + strip_prefix = "google-boringssl-ed2e74e", + type = "tgz", + # from master-with-bazel branch + urls = ["https://github.com/google/boringssl/tarball/ed2e74e737dc802ed9baad1af62c1514430a70d6"], +) + +# Based on https://github.com/bazelbuild/bazel/blob/master/third_party/zlib/BUILD. +_zlib_build = """ +cc_library( + name = "zlib", + srcs = glob(["*.c"]), + hdrs = glob(["*.h"]), + # Temporary workaround for zlib warnings and mac compilation, should no longer be needed with next release https://github.com/madler/zlib/issues/633 + copts = [ + "-w", + "-Dverbose=-1", + ] + select({ + "@platforms//os:macos": [ "-std=c90" ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], +) +""" + +http_archive( + name = "zlib", + build_file_content = _zlib_build, + sha256 = "d14c38e313afc35a9a8760dadf26042f51ea0f5d154b0630a31da0540107fb98", + strip_prefix = "zlib-1.2.13", + urls = ["https://zlib.net/zlib-1.2.13.tar.xz"], +) + +load_brotli() diff --git a/libs/EXTERNAL/capnproto/c++/build/configure.bzl b/libs/EXTERNAL/capnproto/c++/build/configure.bzl new file mode 100644 index 00000000000..bafee637ebb --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/build/configure.bzl @@ -0,0 +1,103 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "int_flag") + +def kj_configure(): + """Generates set of flag, settings for kj configuration. + + Creates kj-defines cc_library with all necessary preprocessor defines. + """ + + # Flags to configure KJ library build. + bool_flag( + name = "openssl", + build_setting_default = False, + ) + + bool_flag( + name = "zlib", + build_setting_default = False, + ) + + bool_flag( + name = "brotli", + build_setting_default = False, + ) + + bool_flag( + name = "libdl", + build_setting_default = False, + ) + + bool_flag( + name = "save_acquired_lock_info", + build_setting_default = False, + ) + + bool_flag( + name = "track_lock_blocking", + build_setting_default = False, + ) + + bool_flag( + name = "coroutines", + build_setting_default = False, + ) + + # Settings to use in select() expressions + native.config_setting( + name = "use_openssl", + flag_values = {"openssl": "True"}, + visibility = ["//visibility:public"], + ) + + native.config_setting( + name = "use_zlib", + flag_values = {"zlib": "True"}, + ) + + native.config_setting( + name = "use_brotli", + flag_values = {"brotli": "True"}, + ) + + native.config_setting( + name = "use_libdl", + flag_values = {"libdl": "True"}, + ) + + native.config_setting( + name = "use_coroutines", + flag_values = {"coroutines": "True"}, + ) + + native.config_setting( + name = "use_save_acquired_lock_info", + flag_values = {"save_acquired_lock_info": "True"}, + ) + + native.config_setting( + name = "use_track_lock_blocking", + flag_values = {"track_lock_blocking": "True"}, + ) + + native.cc_library( + name = "kj-defines", + defines = select({ + "//src/kj:use_openssl": ["KJ_HAS_OPENSSL"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_zlib": ["KJ_HAS_ZLIB"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_brotli": ["KJ_HAS_BROTLI"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_libdl": ["KJ_HAS_LIBDL"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_save_acquired_lock_info": ["KJ_SAVE_ACQUIRED_LOCK_INFO=1"], + "//conditions:default": ["KJ_SAVE_ACQUIRED_LOCK_INFO=0"], + }) + select({ + "//src/kj:use_track_lock_blocking": ["KJ_TRACK_LOCK_BLOCKING=1"], + "//conditions:default": ["KJ_TRACK_LOCK_BLOCKING=0"], + }), + ) diff --git a/libs/EXTERNAL/capnproto/c++/build/load_br.bzl b/libs/EXTERNAL/capnproto/c++/build/load_br.bzl new file mode 100644 index 00000000000..fe6fdfedd6b --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/build/load_br.bzl @@ -0,0 +1,12 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# Defined in a bzl file to allow dependents to pull in brotli via capnproto. Using latest brotli +# commit due to macOS compile issues with v1.0.9, switch to a release version later +def load_brotli(): + http_archive( + name = "brotli", + sha256 = "e33f397d86aaa7f3e786bdf01a7b5cff4101cfb20041c04b313b149d34332f64", + strip_prefix = "google-brotli-ed1995b", + type = "tgz", + urls = ["https://github.com/google/brotli/tarball/ed1995b6bda19244070ab5d331111f16f67c8054"], + ) diff --git a/libs/EXTERNAL/capnproto/c++/cmake/CapnProtoConfig.cmake.in b/libs/EXTERNAL/capnproto/c++/cmake/CapnProtoConfig.cmake.in index 667f502fb55..4b8ac96db4d 100644 --- a/libs/EXTERNAL/capnproto/c++/cmake/CapnProtoConfig.cmake.in +++ b/libs/EXTERNAL/capnproto/c++/cmake/CapnProtoConfig.cmake.in @@ -49,7 +49,7 @@ if(NOT _IMPORT_PREFIX) set(_IMPORT_PREFIX ${PACKAGE_PREFIX_DIR}) endif() -if (@WITH_OPENSSL@) # WITH_OPENSSL +if (@WITH_OPENSSL@) # WITH_OPENSSL include(CMakeFindDependencyMacro) if (CMAKE_VERSION VERSION_LESS 3.9) # find_dependency() did not support COMPONENTS until CMake 3.9 @@ -62,6 +62,43 @@ if (@WITH_OPENSSL@) # WITH_OPENSSL endif() endif() +if (@WITH_ZLIB@) # WITH_ZLIB + include(CMakeFindDependencyMacro) + find_dependency(ZLIB) +endif() + +if (@_WITH_LIBUCONTEXT@) # _WITH_LIBUCONTEXT + set(forwarded_config_flags) + if(CapnProto_FIND_QUIETLY) + list(APPEND forwarded_config_flags QUIET) + endif() + if(CapnProto_FIND_REQUIRED) + list(APPEND forwarded_config_flags REQUIRED) + endif() + # If the consuming project called find_package(CapnProto) with the QUIET or REQUIRED flags, forward + # them to calls to find_package(PkgConfig) and pkg_check_modules(). Note that find_dependency() + # would do this for us in the former case, but there is no such forwarding wrapper for + # pkg_check_modules(). + + find_package(PkgConfig ${forwarded_config_flags}) + if(NOT ${PkgConfig_FOUND}) + # If we're here, the REQUIRED flag must not have been passed, else we would have had a fatal + # error. Nevertheless, a diagnostic for this case is probably nice. + if(NOT CapnProto_FIND_QUIETLY) + message(WARNING "pkg-config cannot be found") + endif() + set(CapnProto_FOUND OFF) + return() + endif() + + if (CMAKE_VERSION VERSION_LESS 3.6) + # CMake >= 3.6 required due to the use of IMPORTED_TARGET + message(SEND_ERROR "libucontext support requires CMake >= 3.6.") + endif() + + pkg_check_modules(libucontext IMPORTED_TARGET ${forwarded_config_flags} libucontext) +endif() + include("${CMAKE_CURRENT_LIST_DIR}/CapnProtoTargets.cmake") include("${CMAKE_CURRENT_LIST_DIR}/CapnProtoMacros.cmake") diff --git a/libs/EXTERNAL/capnproto/c++/compile_flags.txt b/libs/EXTERNAL/capnproto/c++/compile_flags.txt new file mode 100644 index 00000000000..a9e8a51665d --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/compile_flags.txt @@ -0,0 +1,15 @@ +-std=c++20 +-Isrc +-Itmp +-isystem/usr/local/include +-isystem/usr/include/x86_64-linux-gnu +-isystem/usr/include +-DKJ_HEADER_WARNINGS +-DCAPNP_HEADER_WARNINGS +-DCAPNP_DEBUG_TYPES +-DKJ_HAS_OPENSSL +-DKJ_HAS_LIBDL +-DKJ_HAS_ZLIB +-DKJ_HAS_BROTLI +-DKJ_BENCHMARK_MALLOC +-xc++ diff --git a/libs/EXTERNAL/capnproto/c++/configure.ac b/libs/EXTERNAL/capnproto/c++/configure.ac index 72fe8456f19..ba4b4038d12 100644 --- a/libs/EXTERNAL/capnproto/c++/configure.ac +++ b/libs/EXTERNAL/capnproto/c++/configure.ac @@ -1,6 +1,6 @@ ## Process this file with autoconf to produce configure. -AC_INIT([Capn Proto],[0.9.1],[capnproto@googlegroups.com],[capnproto-c++]) +AC_INIT([Capn Proto],[1.0.2],[capnproto@googlegroups.com],[capnproto-c++]) AC_CONFIG_SRCDIR([src/capnp/layout.c++]) AC_CONFIG_AUX_DIR([build-aux]) @@ -32,6 +32,11 @@ AC_ARG_WITH([openssl], [build libkj-tls by linking against openssl @<:@default=check@:>@])], [],[with_openssl=check]) +AC_ARG_WITH([fibers], + [AS_HELP_STRING([--with-fibers], + [build libkj-async with fibers @<:@default=check@:>@])], + [],[with_fibers=check]) + AC_ARG_ENABLE([reflection], [ AS_HELP_STRING([--disable-reflection], [ compile Cap'n Proto in "lite mode", in which all reflection APIs (schema.h, dynamic.h, etc.) @@ -195,8 +200,71 @@ AS_IF([test "$with_openssl" != no], [ ]) AM_CONDITIONAL([BUILD_KJ_TLS], [test "$with_openssl" != no]) -# CapnProtoConfig.cmake.in needs this variable. -AC_SUBST(WITH_OPENSSL, $with_openssl) +# Fibers don't work if exceptions are disabled, so default off in that case. +AS_IF([test "$with_fibers" != no], [ + AC_MSG_CHECKING([if exceptions are enabled]) + AC_COMPILE_IFELSE([void foo() { throw 1; }], [ + AC_MSG_RESULT([yes]) + ], [ + AS_IF([test "$with_fibers" = check], [ + AC_MSG_RESULT([no -- therefore, disabling fibers]) + with_fibers=no + ], [ + AC_MSG_RESULT([no]) + AC_MSG_ERROR([Fibers require exceptions, but your compiler flags disable exceptions. Please either enable exceptions or disable fibers (--without-fibers).]) + ]) + ]) +]) + +# Check for library support necessary for fibers. +AS_IF([test "$with_fibers" != no], [ + case "${host_os}" in + cygwin* | mingw* ) + # Fibers always work on Windows, where there's an explicit API for them. + with_fibers=yes + ;; + * ) + # Fibers need the symbols getcontext, setcontext, swapcontext and makecontext. + # We assume that makecontext implies the rest. + libc_supports_fibers=yes + AC_SEARCH_LIBS([makecontext], [], [], [ + libc_supports_fibers=no + ]) + + AS_IF([test "$libc_supports_fibers" = yes], [ + with_fibers=yes + ], [ + # If getcontext does not exist in libc, try with libucontext + ucontext_supports_fibers=yes + AC_CHECK_LIB(ucontext, [makecontext], [], [ + ucontext_supports_fibers=no + ]) + AS_IF([test "$ucontext_supports_fibers" = yes], [ + ASYNC_LIBS="$ASYNC_LIBS -lucontext" + with_fibers=yes + ], [ + AS_IF([test "$with_fibers" = yes], [ + AC_MSG_ERROR([Missing symbols required for fibers (makecontext, setcontext, ...). Disable fibers (--without-fibers) or install libucontext]) + ], [ + AC_MSG_WARN([could not find required symbols (makecontext, setcontext, ...) -- won't build with fibers]) + with_fibers=no + ]) + ]) + ]) + ;; + esac +]) +AS_IF([test "$with_fibers" = yes], [ + CXXFLAGS="$CXXFLAGS -DKJ_USE_FIBERS" +], [ + CXXFLAGS="$CXXFLAGS -DKJ_USE_FIBERS=0" +]) + +# CapnProtoConfig.cmake.in needs these variables, +# we force them to NO because we don't need the CMake dependency for them, +# the dependencies are provided by the .pc files. +AC_SUBST(WITH_OPENSSL, NO) +AC_SUBST(_WITH_LIBUCONTEXT, NO) AM_CONDITIONAL([HAS_FUZZING_ENGINE], [test "x$LIB_FUZZING_ENGINE" != "x"]) diff --git a/libs/EXTERNAL/capnproto/c++/ekam-build.sh b/libs/EXTERNAL/capnproto/c++/ekam-build.sh new file mode 100755 index 00000000000..eff565a44c9 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/ekam-build.sh @@ -0,0 +1,70 @@ +#! /bin/bash +# +# This file builds Cap'n Proto using Ekam. + +set -euo pipefail + +NPROC=$(nproc) + +if [ ! -e deps/ekam ]; then + mkdir -p deps + git clone https://github.com/capnproto/ekam.git deps/ekam +fi + +if [ ! -e deps/ekam/deps/capnproto ]; then + mkdir -p deps/ekam/deps + ln -s ../../../.. deps/ekam/deps/capnproto +fi + +if [ ! -e deps/ekam/ekam ]; then + (cd deps/ekam && make -j$NPROC) +fi + +OPT_CXXFLAGS= +EXTRA_LIBS= +EKAM_FLAGS= + +while [ $# -gt 0 ]; do + case $1 in + dbg | debug ) + OPT_CXXFLAGS="-g -DCAPNP_DEBUG_TYPES " + ;; + opt | release ) + OPT_CXXFLAGS="-DNDEBUG -O2 -g" + ;; + prof | profile ) + OPT_CXXFLAGS="-DNDEBUG -O2 -g" + EXTRA_LIBS="$EXTRA_LIBS -lprofiler" + ;; + tcmalloc ) + EXTRA_LIBS="$EXTRA_LIBS -ltcmalloc" + ;; + continuous ) + EKAM_FLAGS="-c -n :41315" + ;; + * ) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac + shift +done + +CLANG_CXXFLAGS="-std=c++20 -stdlib=libc++ -pthread -Wall -Wextra -Werror -Wno-strict-aliasing -Wno-sign-compare -Wno-unused-parameter -Wimplicit-fallthrough -Wno-error=unused-command-line-argument -Wno-missing-field-initializers -DKJ_HEADER_WARNINGS -DCAPNP_HEADER_WARNINGS -DKJ_HAS_OPENSSL -DKJ_HAS_LIBDL -DKJ_HAS_ZLIB -DKJ_BENCHMARK_MALLOC" + +export CXX=${CXX:-clang++} +export CC=${CC:-clang} +export LIBS="-lz -ldl -lcrypto -lssl -stdlib=libc++ $EXTRA_LIBS -pthread" +export CXXFLAGS=${CXXFLAGS:-$OPT_CXXFLAGS $CLANG_CXXFLAGS} + +# TODO(someday): Get the protobuf benchmarks working. For now these settings will prevent build +# errors in the benchmarks directory. Note that it's tricky to link against an installed copy +# of libprotobuf because we have to use compatible C++ standard libraries. We either need to +# build libprotobuf from source using libc++, or we need to switch back to libstdc++ when +# enabling libprotobuf. Arguably building from source would be more fair so we can match compiler +# flags for performance comparison purposes, but we'll have to see if ekam is still able to build +# libprotobuf these days... +CXXFLAGS="$CXXFLAGS -DCAPNP_NO_PROTOBUF_BENCHMARK" +export PROTOC=/bin/true + +exec deps/ekam/bin/ekam $EKAM_FLAGS -j$NPROC diff --git a/libs/EXTERNAL/capnproto/c++/samples/CMakeLists.txt b/libs/EXTERNAL/capnproto/c++/samples/CMakeLists.txt index 6a36b175204..d80b5303c0d 100644 --- a/libs/EXTERNAL/capnproto/c++/samples/CMakeLists.txt +++ b/libs/EXTERNAL/capnproto/c++/samples/CMakeLists.txt @@ -17,7 +17,7 @@ # cmake --build . project("Cap'n Proto Samples" CXX) -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.16) find_package(CapnProto CONFIG REQUIRED) diff --git a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-carsales.c++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-carsales.c++ index 40477097abe..7190251a0f7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-carsales.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-carsales.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "carsales.pb.h" #include "protobuf-common.h" @@ -139,3 +141,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::CarSalesTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-catrank.c++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-catrank.c++ index a3036b237d4..648bdb7cd4a 100644 --- a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-catrank.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-catrank.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "catrank.pb.h" #include "protobuf-common.h" @@ -128,3 +130,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::CatRankTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-eval.c++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-eval.c++ index db27b7378a2..b197a0ea719 100644 --- a/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-eval.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/protobuf-eval.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "eval.pb.h" #include "protobuf-common.h" @@ -116,3 +118,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::ExpressionTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/libs/EXTERNAL/capnproto/c++/src/benchmark/runner.c++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/runner.c++ index 5ff07567d11..155324921a7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/benchmark/runner.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/benchmark/runner.c++ @@ -186,7 +186,7 @@ TestResult runTest(Product product, TestCase testCase, Mode mode, Reuse reuse, } char itersStr[64]; - sprintf(itersStr, "%llu", (long long unsigned int)iters); + snprintf(itersStr, sizeof(itersStr), "%llu", (long long unsigned int)iters); argv[4] = itersStr; argv[5] = nullptr; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/BUILD.bazel b/libs/EXTERNAL/capnproto/c++/src/capnp/BUILD.bazel new file mode 100644 index 00000000000..f11a7a24aa2 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/BUILD.bazel @@ -0,0 +1,278 @@ +load("@capnp-cpp//src/capnp:cc_capnp_library.bzl", "cc_capnp_library") + +cc_library( + name = "capnp", + srcs = [ + "any.c++", + "arena.c++", + "blob.c++", + "c++.capnp.c++", + "dynamic.c++", + "layout.c++", + "list.c++", + "message.c++", + "schema.c++", + "schema.capnp.c++", + "schema-loader.c++", + "serialize.c++", + "serialize-packed.c++", + "stream.capnp.c++", + "stringify.c++", + ], + hdrs = [ + "any.h", + "arena.h", + "blob.h", + "c++.capnp.h", + "capability.h", + "common.h", + "dynamic.h", + "endian.h", + "generated-header-support.h", + "layout.h", + "list.h", + "membrane.h", + "message.h", + "orphan.h", + "pointer-helpers.h", + "pretty-print.h", + "raw-schema.h", + "schema.capnp.h", + "schema.h", + "schema-lite.h", + "schema-loader.h", + "schema-parser.h", + "serialize.h", + "serialize-async.h", + "serialize-packed.h", + "serialize-text.h", + "stream.capnp.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + ], +) + +cc_library( + name = "capnp-rpc", + srcs = [ + "capability.c++", + "dynamic-capability.c++", + "ez-rpc.c++", + "membrane.c++", + "persistent.capnp.c++", + "reconnect.c++", + "rpc.c++", + "rpc.capnp.c++", + "rpc-twoparty.c++", + "rpc-twoparty.capnp.c++", + "serialize-async.c++", + ], + hdrs = [ + "ez-rpc.h", + "persistent.capnp.h", + "reconnect.h", + "rpc.capnp.h", + "rpc.h", + "rpc-prelude.h", + "rpc-twoparty.capnp.h", + "rpc-twoparty.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + ":capnp", + ], +) + +cc_library( + name = "capnpc", + srcs = [ + "compiler/compiler.c++", + "compiler/error-reporter.c++", + "compiler/generics.c++", + "compiler/grammar.capnp.c++", + "compiler/lexer.c++", + "compiler/lexer.capnp.c++", + "compiler/node-translator.c++", + "compiler/parser.c++", + "compiler/type-id.c++", + "schema-parser.c++", + "serialize-text.c++", + ], + hdrs = [ + "compiler/compiler.h", + "compiler/error-reporter.h", + "compiler/generics.h", + "compiler/grammar.capnp.h", + "compiler/lexer.capnp.h", + "compiler/lexer.h", + "compiler/module-loader.h", + "compiler/node-translator.h", + "compiler/parser.h", + "compiler/resolver.h", + "compiler/type-id.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + ":capnp", + ], +) + +cc_binary( + name = "capnp_tool", + srcs = [ + "compiler/capnp.c++", + "compiler/module-loader.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + "//src/capnp/compat:json", + ], +) + +cc_binary( + name = "capnpc-c++", + srcs = [ + "compiler/capnpc-c++.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + ], +) + +cc_binary( + name = "capnpc-capnp", + srcs = [ + "compiler/capnpc-capnp.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + ], +) + +# capnp files that are implicitly available for import to any .capnp. +filegroup( + name = "capnp_system_library", + srcs = [ + "c++.capnp", + "schema.capnp", + "stream.capnp", + "//src/capnp/compat:json.capnp", + ], + visibility = ["//visibility:public"], +) + +# library to link with every cc_capnp_library +cc_library( + name = "capnp_runtime", + visibility = ["//visibility:public"], + # include json since it is not exposed as cc_capnp_library + deps = [ + ":capnp", + "//src/capnp/compat:json", + ], +) + +filegroup( + name = "testdata", + srcs = glob(["testdata/**/*"]), +) + +cc_capnp_library( + name = "capnp_test", + srcs = [ + "test.capnp", + "test-import.capnp", + "test-import2.capnp", + ], + data = [ + "c++.capnp", + "schema.capnp", + "stream.capnp", + ":testdata", + ], + include_prefix = "capnp", + src_prefix = "src", +) + +cc_library( + name = "capnp-test", + srcs = ["test-util.c++"], + hdrs = ["test-util.h"], + deps = [ + ":capnp-rpc", + ":capnp_test", + ":capnpc", + "//src/kj:kj-test", + ], + visibility = [":__subpackages__" ] +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [":capnp-test"], +) for f in [ + "any-test.c++", + "blob-test.c++", + "canonicalize-test.c++", + "common-test.c++", + "capability-test.c++", + "compiler/lexer-test.c++", + "compiler/type-id-test.c++", + "dynamic-test.c++", + "encoding-test.c++", + "endian-test.c++", + "ez-rpc-test.c++", + "layout-test.c++", + "membrane-test.c++", + "message-test.c++", + "orphan-test.c++", + "reconnect-test.c++", + "rpc-test.c++", + "rpc-twoparty-test.c++", + "schema-test.c++", + "schema-loader-test.c++", + "schema-parser-test.c++", + "serialize-async-test.c++", + "serialize-packed-test.c++", + "serialize-test.c++", + "serialize-text-test.c++", + "stringify-test.c++", +]] + +cc_test( + name = "endian-reverse-test", + srcs = ["endian-reverse-test.c++"], + deps = [":capnp-test"], + target_compatible_with = select({ + "@platforms//os:windows": ["@platforms//:incompatible"], + "//conditions:default": [], + }), +) + +cc_library( + name = "endian-test-base", + hdrs = ["endian-test.c++"], + deps = [":capnp-test"], +) + +cc_test( + name = "endian-fallback-test", + srcs = ["endian-fallback-test.c++"], + deps = [":endian-test-base"], +) + +cc_test( + name = "fuzz-test", + size = "large", + srcs = ["fuzz-test.c++"], + deps = [":capnp-test"], +) diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/CMakeLists.txt b/libs/EXTERNAL/capnproto/c++/src/capnp/CMakeLists.txt index 3b515507604..9980fde617f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/CMakeLists.txt +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/CMakeLists.txt @@ -66,8 +66,9 @@ add_library(capnp ${capnp_sources}) add_library(CapnProto::capnp ALIAS capnp) target_link_libraries(capnp PUBLIC kj) #make sure external consumers don't need to manually set the include dirs +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) target_include_directories(capnp INTERFACE - $ + $ $ ) # Ensure the library has a version set to match autotools build @@ -212,7 +213,7 @@ if(NOT CAPNP_LITE) install(TARGETS capnp_tool capnpc_cpp capnpc_capnp ${INSTALL_TARGETS_DEFAULT_ARGS}) if(WIN32) - # On Windows platforms symlinks are not guranteed to support. Also differnt version of CMake handle create_symlink in a different way. + # On Windows platforms symlinks are not guaranteed to support. Also different version of CMake handle create_symlink in a different way. # The most portable way in this case just copy the file. install(CODE "execute_process(COMMAND \"${CMAKE_COMMAND}\" -E copy \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnp${CMAKE_EXECUTABLE_SUFFIX}\" \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnpc${CMAKE_EXECUTABLE_SUFFIX}\")") else() diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/arena.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/arena.c++ index 58dd07faf5a..77061db1d7f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/arena.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/arena.c++ @@ -104,7 +104,7 @@ ReaderArena::~ReaderArena() noexcept(false) {} size_t ReaderArena::sizeInWords() { size_t total = segment0.getArray().size(); - for (uint i = 0; ; i++) { + for (uint i = 1; ; i++) { SegmentReader* segment = tryGetSegment(SegmentId(i)); if (segment == nullptr) return total; total += unboundAs(segment->getSize() / WORDS); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/arena.h b/libs/EXTERNAL/capnproto/c++/src/capnp/arena.h index be4785c90c4..aeaff8448d0 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/arena.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/arena.h @@ -99,7 +99,7 @@ class ReadLimiter { // alignas(8) is the default on 64-bit systems, but needed on 32-bit to avoid an expensive // unaligned atomic operation. - KJ_DISALLOW_COPY(ReadLimiter); + KJ_DISALLOW_COPY_AND_MOVE(ReadLimiter); KJ_ALWAYS_INLINE(void setLimit(uint64_t newLimit)) { #if defined(__GNUC__) || defined(__clang__) @@ -174,7 +174,7 @@ class SegmentReader { kj::ArrayPtr ptr; // size guaranteed to fit in SEGMENT_WORD_COUNT_BITS bits ReadLimiter* readLimiter; - KJ_DISALLOW_COPY(SegmentReader); + KJ_DISALLOW_COPY_AND_MOVE(SegmentReader); friend class SegmentBuilder; @@ -226,7 +226,7 @@ class SegmentBuilder: public SegmentReader { [[noreturn]] void throwNotWritable(); - KJ_DISALLOW_COPY(SegmentBuilder); + KJ_DISALLOW_COPY_AND_MOVE(SegmentBuilder); }; class Arena { @@ -246,7 +246,7 @@ class ReaderArena final: public Arena { public: explicit ReaderArena(MessageReader* message); ~ReaderArena() noexcept(false); - KJ_DISALLOW_COPY(ReaderArena); + KJ_DISALLOW_COPY_AND_MOVE(ReaderArena); size_t sizeInWords(); @@ -282,7 +282,7 @@ class BuilderArena final: public Arena { explicit BuilderArena(MessageBuilder* message); BuilderArena(MessageBuilder* message, kj::ArrayPtr segments); ~BuilderArena() noexcept(false); - KJ_DISALLOW_COPY(BuilderArena); + KJ_DISALLOW_COPY_AND_MOVE(BuilderArena); size_t sizeInWords(); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/blob.h b/libs/EXTERNAL/capnproto/c++/src/capnp/blob.h index 7e24e18cea5..451e443dc43 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/blob.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/blob.h @@ -173,8 +173,8 @@ inline kj::StringPtr KJ_STRINGIFY(Text::Builder builder) { return builder.asString(); } -inline bool operator==(const char* a, const Text::Builder& b) { return a == b.asString(); } -inline bool operator!=(const char* a, const Text::Builder& b) { return a != b.asString(); } +inline bool operator==(const char* a, const Text::Builder& b) { return b.asString() == a; } +inline bool operator!=(const char* a, const Text::Builder& b) { return b.asString() != a; } inline Text::Builder::operator kj::StringPtr() const { return kj::StringPtr(content.begin(), content.size() - 1); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp index 2bda5471792..9eaff6d1d28 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp @@ -24,3 +24,25 @@ $namespace("capnp::annotations"); annotation namespace(file): Text; annotation name(field, enumerant, struct, enum, interface, method, param, group, union): Text; + +annotation allowCancellation(interface, method, file) :Void; +# Indicates that the server-side implementation of a method is allowed to be canceled when the +# client requests cancellation. Without this annotation, once a method call has been delivered to +# the server-side application code, any requests by the client to cancel it will be ignored, and +# the method will run to completion anyway. This applies even for local in-process calls. +# +# This behavior applies specifically to implementations that inherit from the C++ `Foo::Server` +# interface. The annotation won't affect DynamicCapability::Server implementations; they must set +# the cancellation mode at runtime. +# +# When applied to an interface rather than an individual method, the annotation applies to all +# methods in the interface. When applied to a file, it applies to all methods defined in the file. +# +# It's generally recommended that this annotation be applied to all methods. However, when doing +# so, it is important that the server implementation use cancellation-safe code. See: +# +# https://github.com/capnproto/capnproto/blob/master/kjdoc/tour.md#cancellation +# +# If your code is not cancellation-safe, then allowing cancellation might give a malicious client +# an easy way to induce use-after-free or other bugs in your server, by requesting cancellation +# when not expected. diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.c++ index 576d733b23d..02378a9c889 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.c++ @@ -32,7 +32,7 @@ static const ::capnp::_::AlignedData<21> b_b9c6f99ebf805f2c = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_b9c6f99ebf805f2c = { 0xb9c6f99ebf805f2c, b_b9c6f99ebf805f2c.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_b9c6f99ebf805f2c, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_b9c6f99ebf805f2c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<20> b_f264a779fef191ce = { @@ -61,7 +61,38 @@ static const ::capnp::_::AlignedData<20> b_f264a779fef191ce = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_f264a779fef191ce = { 0xf264a779fef191ce, b_f264a779fef191ce.words, 20, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_f264a779fef191ce, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_f264a779fef191ce, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<22> b_ac7096ff8cfc9dce = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 206, 157, 252, 140, 255, 150, 112, 172, + 16, 0, 0, 0, 5, 0, 1, 3, + 129, 78, 48, 184, 123, 125, 248, 189, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 18, 1, 0, 0, + 37, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 32, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 43, + 43, 46, 99, 97, 112, 110, 112, 58, + 97, 108, 108, 111, 119, 67, 97, 110, + 99, 101, 108, 108, 97, 116, 105, 111, + 110, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_ac7096ff8cfc9dce = b_ac7096ff8cfc9dce.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_ac7096ff8cfc9dce = { + 0xac7096ff8cfc9dce, b_ac7096ff8cfc9dce.words, 22, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_ac7096ff8cfc9dce, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.h index 73d35cad904..6fc28fb69a6 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/c++.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif @@ -18,6 +20,7 @@ namespace schemas { CAPNP_DECLARE_SCHEMA(b9c6f99ebf805f2c); CAPNP_DECLARE_SCHEMA(f264a779fef191ce); +CAPNP_DECLARE_SCHEMA(ac7096ff8cfc9dce); } // namespace schemas } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/capability-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/capability-test.c++ index 6a934a94c83..a645abce40b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/capability-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/capability-test.c++ @@ -280,7 +280,7 @@ TEST(Capability, TailCall) { } TEST(Capability, AsyncCancelation) { - // Tests allowCancellation(). + // Tests cancellation. kj::EventLoop loop; kj::WaitScope waitScope(loop); @@ -1304,10 +1304,8 @@ KJ_TEST("Streaming calls can be canceled") { auto promise4 = cap.finishStreamRequest().send(); - // Cancel the streaming calls. - promise1 = nullptr; + // Cancel the doStreamJ() request. promise2 = nullptr; - promise3 = nullptr; KJ_EXPECT(server.iSum == 0); KJ_EXPECT(server.jSum == 0); @@ -1321,10 +1319,9 @@ KJ_TEST("Streaming calls can be canceled") { KJ_EXPECT(!promise4.poll(waitScope)); - // The call to doStreamJ() opted into cancellation so the next call to doStreamI() happens - // immediately. + // The call to doStreamJ() was canceled, so the next call to doStreamI() happens immediately. KJ_EXPECT(server.iSum == 579); - KJ_EXPECT(server.jSum == 321); + KJ_EXPECT(server.jSum == 0); KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); @@ -1332,7 +1329,7 @@ KJ_TEST("Streaming calls can be canceled") { auto result = promise4.wait(waitScope); KJ_EXPECT(result.getTotalI() == 579); - KJ_EXPECT(result.getTotalJ() == 321); + KJ_EXPECT(result.getTotalJ() == 0); } KJ_TEST("Streaming call throwing cascades to following calls") { @@ -1393,6 +1390,35 @@ KJ_TEST("Streaming call throwing cascades to following calls") { KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("throw requested", promise4.ignoreResult().wait(waitScope)); } +KJ_TEST("RevocableServer") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + class ServerImpl: public test::TestMembrane::Server { + public: + kj::Promise waitForever(WaitForeverContext context) override { + return kj::NEVER_DONE; + } + }; + + ServerImpl server; + + RevocableServer revocable(server); + + auto promise = revocable.getClient().waitForeverRequest().send(); + KJ_EXPECT(!promise.poll(waitScope)); + + revocable.revoke(); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "capability was revoked", + promise.ignoreResult().wait(waitScope)); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "capability was revoked", + revocable.getClient().waitForeverRequest().send().ignoreResult().wait(waitScope)); +} + } // namespace } // namespace _ } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/capability.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/capability.c++ index dc5f7f2e3bf..9462c6c218e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/capability.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/capability.c++ @@ -91,7 +91,7 @@ Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented return { KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", actualInterfaceName, requestedTypeId), - false + false, true }; } @@ -99,7 +99,7 @@ Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented const char* interfaceName, uint64_t typeId, uint16_t methodId) { return { KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId), - false + false, true }; } @@ -135,7 +135,7 @@ static inline uint firstSegmentSize(kj::Maybe sizeHint) { } } -class LocalResponse final: public ResponseHook, public kj::Refcounted { +class LocalResponse final: public ResponseHook { public: LocalResponse(kj::Maybe sizeHint) : message(firstSegmentSize(sizeHint)) {} @@ -146,9 +146,9 @@ public: class LocalCallContext final: public CallContextHook, public ResponseHook, public kj::Refcounted { public: LocalCallContext(kj::Own&& request, kj::Own clientRef, - kj::Own> cancelAllowedFulfiller) - : request(kj::mv(request)), clientRef(kj::mv(clientRef)), - cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {} + ClientHook::CallHints hints, bool isStreaming) + : request(kj::mv(request)), clientRef(kj::mv(clientRef)), hints(hints), + isStreaming(isStreaming) {} AnyPointer::Reader getParams() override { KJ_IF_MAYBE(r, request) { @@ -162,7 +162,7 @@ public: } AnyPointer::Builder getResults(kj::Maybe sizeHint) override { if (response == nullptr) { - auto localResponse = kj::refcounted(sizeHint); + auto localResponse = kj::heap(sizeHint); responseBuilder = localResponse->message.getRoot(); response = Response(responseBuilder.asReader(), kj::mv(localResponse)); } @@ -183,22 +183,31 @@ public: ClientHook::VoidPromiseAndPipeline directTailCall(kj::Own&& request) override { KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct."); - auto promise = request->send(); + if (hints.onlyPromisePipeline) { + return { + kj::NEVER_DONE, + PipelineHook::from(request->sendForPipeline()) + }; + } - auto voidPromise = promise.then([this](Response&& tailResponse) { - response = kj::mv(tailResponse); - }); + if (isStreaming) { + auto promise = request->sendStreaming(); + return { kj::mv(promise), getDisabledPipeline() }; + } else { + auto promise = request->send(); - return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; + auto voidPromise = promise.then([this](Response&& tailResponse) { + response = kj::mv(tailResponse); + }); + + return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; + } } kj::Promise onTailCall() override { auto paf = kj::newPromiseAndFulfiller(); tailCallPipelineFulfiller = kj::mv(paf.fulfiller); return kj::mv(paf.promise); } - void allowCancellation() override { - cancelAllowedFulfiller->fulfill(); - } kj::Own addRef() override { return kj::addRef(*this); } @@ -208,39 +217,63 @@ public: AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null kj::Own clientRef; kj::Maybe>> tailCallPipelineFulfiller; - kj::Own> cancelAllowedFulfiller; + ClientHook::CallHints hints; + bool isStreaming; }; class LocalRequest final: public RequestHook { public: inline LocalRequest(uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint, kj::Own client) + kj::Maybe sizeHint, ClientHook::CallHints hints, + kj::Own client) : message(kj::heap(firstSegmentSize(sizeHint))), - interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} + interfaceId(interfaceId), methodId(methodId), hints(hints), client(kj::mv(client)) {} RemotePromise send() override { - KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + bool isStreaming = false; + return sendImpl(isStreaming); + } - auto cancelPaf = kj::newPromiseAndFulfiller(); + kj::Promise sendStreaming() override { + // We don't do any special handling of streaming in RequestHook for local requests, because + // there is no latency to compensate for between the client and server in this case. However, + // we record whether the call was streaming, so that it can be preserved as a streaming call + // if the local capability later resolves to a remote capability. + bool isStreaming = true; + return sendImpl(isStreaming).ignoreResult(); + } + AnyPointer::Pipeline sendForPipeline() override { + KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + + hints.onlyPromisePipeline = true; + bool isStreaming = false; auto context = kj::refcounted( - kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller)); - auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context)); + kj::mv(message), client->addRef(), hints, isStreaming); + auto vpap = client->call(interfaceId, methodId, kj::addRef(*context), hints); + return AnyPointer::Pipeline(kj::mv(vpap.pipeline)); + } + + const void* getBrand() override { + return nullptr; + } - // We have to make sure the call is not canceled unless permitted. We need to fork the promise - // so that if the client drops their copy, the promise isn't necessarily canceled. - auto forked = promiseAndPipeline.promise.fork(); + kj::Own message; - // We daemonize one branch, but only after joining it with the promise that fires if - // cancellation is allowed. - forked.addBranch() - .attach(kj::addRef(*context)) - .exclusiveJoin(kj::mv(cancelPaf.promise)) - .detach([](kj::Exception&&) {}); // ignore exceptions +private: + uint64_t interfaceId; + uint16_t methodId; + ClientHook::CallHints hints; + kj::Own client; + + RemotePromise sendImpl(bool isStreaming) { + KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + + auto context = kj::refcounted(kj::mv(message), client->addRef(), hints, isStreaming); + auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context), hints); // Now the other branch returns the response from the context. - auto promise = forked.addBranch().then(kj::mvCapture(context, - [](kj::Own&& context) { + auto promise = promiseAndPipeline.promise.then([context=kj::mv(context)]() mutable { // force response allocation auto reader = context->getResults(MessageSize { 0, 0 }).asReader(); @@ -258,29 +291,12 @@ public: } else { return kj::mv(KJ_ASSERT_NONNULL(context->response)); } - })); + }); // We return the other branch. return RemotePromise( kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline))); } - - kj::Promise sendStreaming() override { - // We don't do any special handling of streaming in RequestHook for local requests, because - // there is no latency to compensate for between the client and server in this case. - return send().ignoreResult(); - } - - const void* getBrand() override { - return nullptr; - } - - kj::Own message; - -private: - uint64_t interfaceId; - uint16_t methodId; - kj::Own client; }; // ======================================================================================= @@ -385,65 +401,47 @@ public: promiseForClientResolution(promise.addBranch().fork()) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { auto hook = kj::heap( - interfaceId, methodId, sizeHint, kj::addRef(*this)); + interfaceId, methodId, sizeHint, hints, kj::addRef(*this)); auto root = hook->message->getRoot(); return Request(root, kj::mv(hook)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - // This is a bit complicated. We need to initiate this call later on. When we initiate the - // call, we'll get a void promise for its completion and a pipeline object. Right now, we have - // to produce a similar void promise and pipeline that will eventually be chained to those. - // The problem is, these are two independent objects, but they both depend on the result of - // one future call. - // - // So, we need to set up a continuation that will initiate the call later, then we need to - // fork the promise for that continuation in order to send the completion promise and the - // pipeline to their respective places. - // - // TODO(perf): Too much reference counting? Can we do better? Maybe a way to fork - // Promise> into Tuple, Promise>? - - struct CallResultHolder: public kj::Refcounted { - // Essentially acts as a refcounted \VoidPromiseAndPipeline, so that we can create a promise - // for it and fork that promise. - - VoidPromiseAndPipeline content; - // One branch of the fork will use content.promise, the other branch will use - // content.pipeline. Neither branch will touch the other's piece. - - inline CallResultHolder(VoidPromiseAndPipeline&& content): content(kj::mv(content)) {} - - kj::Own addRef() { return kj::addRef(*this); } - }; - - // Create a promise for the call initiation. - kj::ForkedPromise> callResultPromise = - promiseForCallForwarding.addBranch().then(kj::mvCapture(context, - [=](kj::Own&& context, kj::Own&& client){ - return kj::refcounted( - client->call(interfaceId, methodId, kj::mv(context))); - })).fork(); - - // Create a promise that extracts the pipeline from the call initiation, and construct our - // QueuedPipeline to chain to it. - auto pipelinePromise = callResultPromise.addBranch().then( - [](kj::Own&& callResult){ - return kj::mv(callResult->content.pipeline); - }); - auto pipeline = kj::refcounted(kj::mv(pipelinePromise)); + kj::Own&& context, CallHints hints) override { + if (hints.noPromisePipelining) { + // Optimize for no pipelining. + auto promise = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + return client->call(interfaceId, methodId, kj::mv(context), hints).promise; + }); + return VoidPromiseAndPipeline { kj::mv(promise), getDisabledPipeline() }; + } else if (hints.onlyPromisePipeline) { + auto pipelinePromise = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + return client->call(interfaceId, methodId, kj::mv(context), hints).pipeline; + }); + return VoidPromiseAndPipeline { + kj::NEVER_DONE, + kj::refcounted(kj::mv(pipelinePromise)) + }; + } else { + auto split = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + auto vpap = client->call(interfaceId, methodId, kj::mv(context), hints); + return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline)); + }).split(); - // Create a promise that simply chains to the void promise produced by the call initiation. - auto completionPromise = callResultPromise.addBranch().then( - [](kj::Own&& callResult){ - return kj::mv(callResult->content.promise); - }); + kj::Promise completionPromise = kj::mv(kj::get<0>(split)); + kj::Promise> pipelinePromise = kj::mv(kj::get<1>(split)); - // OK, now we can actually return our thing. - return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) }; + auto pipeline = kj::refcounted(kj::mv(pipelinePromise)); + + // OK, now we can actually return our thing. + return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) }; + } } kj::Maybe getResolved() override { @@ -542,46 +540,62 @@ private: class LocalClient final: public ClientHook, public kj::Refcounted { public: - LocalClient(kj::Own&& serverParam) - : server(kj::mv(serverParam)) { - server->thisHook = this; - startResolveTask(); + LocalClient(kj::Own&& serverParam, bool revocable = false) { + auto& serverRef = *server.emplace(kj::mv(serverParam)); + serverRef.thisHook = this; + if (revocable) revoker.emplace(); + startResolveTask(serverRef); } LocalClient(kj::Own&& serverParam, - _::CapabilityServerSetBase& capServerSet, void* ptr) - : server(kj::mv(serverParam)), capServerSet(&capServerSet), ptr(ptr) { - server->thisHook = this; - startResolveTask(); + _::CapabilityServerSetBase& capServerSet, void* ptr, + bool revocable = false) + : capServerSet(&capServerSet), ptr(ptr) { + auto& serverRef = *server.emplace(kj::mv(serverParam)); + serverRef.thisHook = this; + if (revocable) revoker.emplace(); + startResolveTask(serverRef); } ~LocalClient() noexcept(false) { - server->thisHook = nullptr; + KJ_IF_MAYBE(s, server) { + s->get()->thisHook = nullptr; + } + } + + void revoke(kj::Exception&& e) { + KJ_IF_MAYBE(s, server) { + KJ_ASSERT_NONNULL(revoker).cancel(e); + brokenException = kj::mv(e); + s->get()->thisHook = nullptr; + server = nullptr; + } } Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { KJ_IF_MAYBE(r, resolved) { // We resolved to a shortened path. New calls MUST go directly to the replacement capability // so that their ordering is consistent with callers who call getResolved() to get direct // access to the new capability. In particular it's important that we don't place these calls // in our streaming queue. - return r->get()->newCall(interfaceId, methodId, sizeHint); + return r->get()->newCall(interfaceId, methodId, sizeHint, hints); } auto hook = kj::heap( - interfaceId, methodId, sizeHint, kj::addRef(*this)); + interfaceId, methodId, sizeHint, hints, kj::addRef(*this)); auto root = hook->message->getRoot(); return Request(root, kj::mv(hook)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, CallHints hints) override { KJ_IF_MAYBE(r, resolved) { // We resolved to a shortened path. New calls MUST go directly to the replacement capability // so that their ordering is consistent with callers who call getResolved() to get direct // access to the new capability. In particular it's important that we don't place these calls // in our streaming queue. - return r->get()->call(interfaceId, methodId, kj::mv(context)); + return r->get()->call(interfaceId, methodId, kj::mv(context), hints); } auto contextPtr = context.get(); @@ -603,23 +617,50 @@ public: } }).attach(kj::addRef(*this)); - // We have to fork this promise for the pipeline to receive a copy of the answer. - auto forked = promise.fork(); + if (hints.noPromisePipelining) { + // No need to set up pipelining.. + + // Make sure we release the params on return, since we would on the normal pipelining path. + // TODO(perf): Experiment with whether this is actually useful. It seems likely the params + // will be released soon anyway, so maybe this is a waste? + promise = promise.then([context=kj::mv(context)]() mutable { + context->releaseParams(); + }); + + // When we do apply pipelining, the use of `.fork()` has the side effect of eagerly + // evaluating the promise. To match the behavior here, we use `.eagerlyEvaluate()`. + // TODO(perf): Maybe we don't need to match behavior? It did break some tests but arguably + // those tests are weird and not what a real program would do... + promise = promise.eagerlyEvaluate(nullptr); + return VoidPromiseAndPipeline { kj::mv(promise), getDisabledPipeline() }; + } + + kj::Promise completionPromise = nullptr; + kj::Promise pipelineBranch = nullptr; + + if (hints.onlyPromisePipeline) { + pipelineBranch = kj::mv(promise); + completionPromise = kj::NEVER_DONE; + } else { + // We have to fork this promise for the pipeline to receive a copy of the answer. + auto forked = promise.fork(); + pipelineBranch = forked.addBranch(); + completionPromise = forked.addBranch().attach(context->addRef()); + } - auto pipelinePromise = forked.addBranch().then(kj::mvCapture(context->addRef(), - [=](kj::Own&& context) -> kj::Own { + auto pipelinePromise = pipelineBranch + .then([=,context=context->addRef()]() mutable -> kj::Own { context->releaseParams(); return kj::refcounted(kj::mv(context)); - })); + }); - auto tailPipelinePromise = context->onTailCall().then([](AnyPointer::Pipeline&& pipeline) { + auto tailPipelinePromise = context->onTailCall() + .then([context = context->addRef()](AnyPointer::Pipeline&& pipeline) { return kj::mv(pipeline.hook); }); pipelinePromise = pipelinePromise.exclusiveJoin(kj::mv(tailPipelinePromise)); - auto completionPromise = forked.addBranch().attach(kj::mv(context)); - return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::refcounted(kj::mv(pipelinePromise)) }; } @@ -688,19 +729,30 @@ public: } kj::Maybe getFd() override { - return server->getFd(); + KJ_IF_MAYBE(s, server) { + return s->get()->getFd(); + } else { + return nullptr; + } } private: - kj::Own server; + kj::Maybe> server; _::CapabilityServerSetBase* capServerSet = nullptr; void* ptr = nullptr; kj::Maybe> resolveTask; kj::Maybe> resolved; - void startResolveTask() { - resolveTask = server->shortenPath().map([this](kj::Promise promise) { + kj::Maybe revoker; + // If non-null, all promises must be wrapped in this revoker. + + void startResolveTask(Capability::Server& serverRef) { + resolveTask = serverRef.shortenPath().map([this](kj::Promise promise) { + KJ_IF_MAYBE(r, revoker) { + promise = r->wrap(kj::mv(promise)); + } + return promise.then([this](Capability::Client&& cap) { auto hook = ClientHook::from(kj::mv(cap)); @@ -816,8 +868,24 @@ private: return kj::cp(*e); } - auto result = server->dispatchCall(interfaceId, methodId, + // `server` can't be null here since `brokenException` is null. + auto result = KJ_ASSERT_NONNULL(server)->dispatchCall(interfaceId, methodId, CallContext(context)); + + KJ_IF_MAYBE(r, revoker) { + result.promise = r->wrap(kj::mv(result.promise)); + } + + if (!result.allowCancellation) { + // Make sure this call cannot be canceled by forking the promise and detaching one branch. + auto fork = result.promise.attach(kj::addRef(*this), context.addRef()).fork(); + result.promise = fork.addBranch(); + fork.addBranch().detach([](kj::Exception&&) { + // Exception from canceled call is silently discarded. The caller should have waited for + // it if they cared. + }); + } + if (result.isStreaming) { return result.promise .catch_([this](kj::Exception&& e) { @@ -836,6 +904,19 @@ kj::Own Capability::Client::makeLocalClient(kj::Own(kj::mv(server)); } +kj::Own Capability::Client::makeRevocableLocalClient(Capability::Server& server) { + auto result = kj::refcounted( + kj::Own(&server, kj::NullDisposer::instance), true /* revocable */); + return result; +} +void Capability::Client::revokeLocalClient(ClientHook& hook) { + revokeLocalClient(hook, KJ_EXCEPTION(FAILED, + "capability was revoked (RevocableServer was destroyed)")); +} +void Capability::Client::revokeLocalClient(ClientHook& hook, kj::Exception&& e) { + kj::downcast(hook).revoke(kj::mv(e)); +} + kj::Own newLocalPromiseClient(kj::Promise>&& promise) { return kj::refcounted(kj::mv(promise)); } @@ -906,6 +987,10 @@ public: return kj::cp(exception); } + AnyPointer::Pipeline sendForPipeline() override { + return AnyPointer::Pipeline(kj::refcounted(exception)); + } + const void* getBrand() override { return nullptr; } @@ -923,12 +1008,13 @@ public: resolved(resolved), brand(brand) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { return newBrokenRequest(kj::cp(exception), sizeHint); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, CallHints hints) override { return VoidPromiseAndPipeline { kj::cp(exception), kj::refcounted(exception) }; } @@ -993,6 +1079,27 @@ Request newBrokenRequest( return Request(root, kj::mv(hook)); } +kj::Own getDisabledPipeline() { + class DisabledPipelineHook final: public PipelineHook { + public: + kj::Own addRef() override { + return kj::Own(this, kj::NullDisposer::instance); + } + + kj::Own getPipelinedCap(kj::ArrayPtr ops) override { + return newBrokenCap(KJ_EXCEPTION(FAILED, + "caller specified noPromisePipelining hint, but then tried to pipeline")); + } + + kj::Own getPipelinedCap(kj::Array&& ops) override { + return newBrokenCap(KJ_EXCEPTION(FAILED, + "caller specified noPromisePipelining hint, but then tried to pipeline")); + } + }; + static DisabledPipelineHook instance; + return instance.addRef(); +} + // ======================================================================================= ReaderCapabilityTable::ReaderCapabilityTable( diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/capability.h b/libs/EXTERNAL/capnproto/c++/src/capnp/capability.h index 125d2ccf252..1e71840ac5c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/capability.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/capability.h @@ -101,6 +101,8 @@ class RequestHook; class ResponseHook; class PipelineHook; class ClientHook; +template +class RevocableServer; template class Request: public Params::Builder { @@ -119,6 +121,24 @@ class Request: public Params::Builder { RemotePromise send() KJ_WARN_UNUSED_RESULT; // Send the call and return a promise for the results. + typename Results::Pipeline sendForPipeline(); + // Send the call in pipeline-only mode. The returned object can be used to make pipelined calls, + // but there is no way to wait for the completion of the original call. This allows some + // bookkeeping to be skipped under the hood, saving some time. + // + // Generally, this method should only be used when the caller will immediately make one or more + // pipelined calls on the result, and then throw away the pipeline and all pipelined + // capabilities. Other uses may run into caveats, such as: + // - Normally, calling `whenResolved()` on a pipelined capability would wait for the original RPC + // to complete (and possibly other things, if that RPC itself returned a promise capability), + // but when using `sendPipelineOnly()`, `whenResolved()` may complete immediately, or never, or + // at an arbitrary time. Do not rely on it. + // - Normal path shortening may not work with these capabilities. For exmaple, if the caller + // forwards a pipelined capability back to the callee's vat, calls made by the callee to that + // capability may continue to proxy through the caller. Conversely, if the callee ends up + // returning a capability that points back to the caller's vat, calls on the pipelined + // capability may continue to proxy through the callee. + private: kj::Own hook; @@ -225,9 +245,25 @@ class Capability::Client { // where no calls are being made. There is no reason to wait for this before making calls; if // the capability does not resolve, the call results will propagate the error. + struct CallHints { + bool noPromisePipelining = false; + // Hints that the pipeline part of the VoidPromiseAndPipeline won't be used, so it can be + // a bogus object. + + bool onlyPromisePipeline = false; + // Hints that the promise part of the VoidPromiseAndPipeline won't be used, so it can be a + // bogus promise. + // + // This hint is primarily intended to be passed to `ClientHook::call()`. When using + // `ClientHook::newCall()`, you would instead indicate the hint by calling the `ResponseHook`'s + // `sendForPipeline()` method. The effect of setting `onlyPromisePipeline = true` when invoking + // `ClientHook::newCall()` is unspecified; it might cause the returned `Request` to support + // only pipelining even when `send()` is called, or it might not. + }; + Request typelessRequest( uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint); + kj::Maybe sizeHint, CallHints hints); // Make a request without knowing the types of the params or results. You specify the type ID // and method number manually. @@ -251,15 +287,18 @@ class Capability::Client { template Request newCall(uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint); + kj::Maybe sizeHint, CallHints hints); template StreamingRequest newStreamingCall(uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint); + kj::Maybe sizeHint, CallHints hints); private: kj::Own hook; static kj::Own makeLocalClient(kj::Own&& server); + static kj::Own makeRevocableLocalClient(Capability::Server& server); + static void revokeLocalClient(ClientHook& hook); + static void revokeLocalClient(ClientHook& hook, kj::Exception&& reason); template friend struct _::PointerHelpers; @@ -271,6 +310,8 @@ class Capability::Client { friend struct List; friend class _::CapabilityServerSetBase; friend class ClientHook; + template + friend class RevocableServer; }; // ======================================================================================= @@ -366,6 +407,11 @@ class CallContext: public kj::DisallowConstCopy { // Note: This method has an overload that takes an lvalue reference for convenience. This // overload increments the refcount on the underlying PipelineHook -- it does not keep the // reference. + // + // Note: Capabilities returned by the replacement pipeline MUST either be exactly the same + // capabilities as in the final response, or eventually resolve to exactly the same + // capabilities, where "exactly the same" means the underlying `ClientHook` object is exactly + // the same object by identity. Resolving to some "equivalent" capability is not good enough. template kj::Promise tailCall(Request&& tailRequest); @@ -380,33 +426,20 @@ class CallContext: public kj::DisallowConstCopy { // In general, this should be the last thing a method implementation calls, and the promise // returned from `tailCall()` should then be returned by the method implementation. - void allowCancellation(); - // Indicate that it is OK for the RPC system to discard its Promise for this call's result if - // the caller cancels the call, thereby transitively canceling any asynchronous operations the - // call implementation was performing. This is not done by default because it could represent a - // security risk: applications must be carefully written to ensure that they do not end up in - // a bad state if an operation is canceled at an arbitrary point. However, for long-running - // method calls that hold significant resources, prompt cancellation is often useful. - // - // Keep in mind that asynchronous cancellation cannot occur while the method is synchronously - // executing on a local thread. The method must perform an asynchronous operation or call - // `EventLoop::current().evalLater()` to yield control. - // - // Note: You might think that we should offer `onCancel()` and/or `isCanceled()` methods that - // provide notification when the caller cancels the request without forcefully killing off the - // promise chain. Unfortunately, this composes poorly with promise forking: the canceled - // path may be just one branch of a fork of the result promise. The other branches still want - // the call to continue. Promise forking is used within the Cap'n Proto implementation -- in - // particular each pipelined call forks the result promise. So, if a caller made a pipelined - // call and then dropped the original object, the call should not be canceled, but it would be - // excessively complicated for the framework to avoid notififying of cancellation as long as - // pipelined calls still exist. + void allowCancellation() + KJ_UNAVAILABLE( + "As of Cap'n Proto 1.0, allowCancellation must be applied statically using an " + "annotation in the schema. See annotations defined in /capnp/c++.capnp. For " + "DynamicCapability::Server, use the constructor option (the annotation does not apply " + "to DynamicCapability). This change was made to gain a significant performance boost -- " + "dynamically allowing cancellation required excessive bookkeeping."); private: CallContextHook* hook; friend class Capability::Server; friend struct DynamicCapability; + friend class CallContextHook; }; template @@ -424,13 +457,20 @@ class StreamingCallContext: public kj::DisallowConstCopy { // - It wouldn't be particularly useful since streaming calls don't return anything, and they // already compensate for latency. - void allowCancellation(); + void allowCancellation() + KJ_UNAVAILABLE( + "As of Cap'n Proto 1.0, allowCancellation must be applied statically using an " + "annotation in the schema. See annotations defined in /capnp/c++.capnp. For " + "DynamicCapability::Server, use the constructor option (the annotation does not apply " + "to DynamicCapability). This change was made to gain a significant performance boost -- " + "dynamically allowing cancellation required excessive bookkeeping."); private: CallContextHook* hook; friend class Capability::Server; friend struct DynamicCapability; + friend class CallContextHook; }; class Capability::Server { @@ -449,13 +489,20 @@ class Capability::Server { // If true, this method was declared as `-> stream;`. No other calls should be permitted until // this call finishes, and if this call throws an exception, all future calls will throw the // same exception. + + bool allowCancellation = false; + // If true, the call can be canceled normally. If false, the immediate caller is responsible + // for ensuring that cancellation is prevented and that `context` remains valid until the + // call completes normally. + // + // See the `allowCancellation` annotation defined in `c++.capnp`. }; virtual DispatchCallResult dispatchCall(uint64_t interfaceId, uint16_t methodId, CallContext context) = 0; // Call the given method. `params` is the input struct, and should be released as soon as it - // is no longer needed. `context` may be used to allocate the output struct and deal with - // cancellation. + // is no longer needed. `context` may be used to allocate the output struct and other call + // logistics. virtual kj::Maybe getFd() { return nullptr; } // If this capability is backed by a file descriptor that is safe to directly expose to clients, @@ -491,6 +538,7 @@ class Capability::Server { // the server's constructor.) // - The capability client pointing at this object has been destroyed. (This is always the case // in the server's destructor.) + // - The capability client pointing at this object has been revoked using RevocableServer. // - Multiple capability clients have been created around the same server (possible if the server // is refcounted, which is not recommended since the client itself provides refcounting). @@ -512,6 +560,39 @@ class Capability::Server { friend class LocalClient; }; +template +class RevocableServer { + // Allows you to create a capability client pointing to a capability server without taking + // ownership of the server. When `RevocableServer` is destroyed, all clients created through it + // will become broken. All outstanding RPCs via those clients will be canceled and all future + // RPCs will immediately throw. Hence, once the `RevocableServer` is destroyed, it is safe + // to destroy the server object it referenced. + // + // This is particularly useful when you want to create a capability server that points to an + // object that you do not own, and thus cannot keep alive beyond some defined lifetime. Since + // you cannot force the client to respect lifetime rules, you should use a RevocableServer to + // revoke access before the lifetime ends. + // + // The RevocableServer object can be moved (as long as the server outlives it). + +public: + RevocableServer(typename T::Server& server); + RevocableServer(RevocableServer&&) = default; + RevocableServer& operator=(RevocableServer&&) = default; + ~RevocableServer() noexcept(false); + KJ_DISALLOW_COPY(RevocableServer); + + typename T::Client getClient(); + + void revoke(); + void revoke(kj::Exception&& reason); + // Revokes the capability immediately, rather than waiting for the destructor. This can also + // be used to specify a custom exception to use when revoking. + +private: + kj::Own hook; +}; + // ======================================================================================= template @@ -558,7 +639,7 @@ class ReaderCapabilityTable: private _::CapTableReader { public: explicit ReaderCapabilityTable(kj::Array>> table); - KJ_DISALLOW_COPY(ReaderCapabilityTable); + KJ_DISALLOW_COPY_AND_MOVE(ReaderCapabilityTable); template T imbue(T reader); @@ -579,7 +660,7 @@ class BuilderCapabilityTable: private _::CapTableBuilder { public: BuilderCapabilityTable(); - KJ_DISALLOW_COPY(BuilderCapabilityTable); + KJ_DISALLOW_COPY_AND_MOVE(BuilderCapabilityTable); inline kj::ArrayPtr>> getTable() { return table; } @@ -624,7 +705,7 @@ class CapabilityServerSet: private _::CapabilityServerSetBase { public: CapabilityServerSet() = default; - KJ_DISALLOW_COPY(CapabilityServerSet); + KJ_DISALLOW_COPY_AND_MOVE(CapabilityServerSet); typename T::Client add(kj::Own&& server); // Create a new capability Client for the given Server and also add this server to the set. @@ -651,6 +732,9 @@ class RequestHook { virtual kj::Promise sendStreaming() = 0; // Send a streaming call. + virtual AnyPointer::Pipeline sendForPipeline() = 0; + // Send a call for pipelining purposes only. + virtual const void* getBrand() = 0; // Returns a void* that identifies who made this request. This can be used by an RPC adapter to // discover when tail call is going to be sent over its own connection and therefore can be @@ -684,8 +768,11 @@ class ClientHook { public: ClientHook(); + using CallHints = Capability::Client::CallHints; + virtual Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) = 0; + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) = 0; // Start a new call, allowing the client to allocate request/response objects as it sees fit. // This version is used when calls are made from application code in the local process. @@ -695,17 +782,13 @@ class ClientHook { }; virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) = 0; + kj::Own&& context, CallHints hints) = 0; // Call the object, but the caller controls allocation of the request/response objects. If the // callee insists on allocating these objects itself, it must make a copy. This version is used // when calls come in over the network via an RPC system. Note that even if the returned // `Promise` is discarded, the call may continue executing if any pipelined calls are // waiting for it. // - // Since the caller of this method chooses the CallContext implementation, it is the caller's - // responsibility to ensure that the returned promise is not canceled unless allowed via - // the context's `allowCancellation()`. - // // The call must not begin synchronously; the callee must arrange for the call to begin in a // later turn of the event loop. Otherwise, application code may call back and affect the // callee's state in an unexpected way. @@ -715,12 +798,22 @@ class ClientHook { // of the capability. The caller may permanently replace this client with the resolved one if // desired. Returns null if the client isn't a promise or hasn't resolved yet -- use // `whenMoreResolved()` to distinguish between them. + // + // Once a particular ClientHook's `getResolved()` returns non-null, it must permanently return + // exactly the same resolution. This is why `getResolved()` returns a reference -- it is assumed + // this object must have a strong reference to the resolution which it intends to keep + // permanently, therefore the returned reference will live at least as long as this `ClientHook`. + // This "only one resolution" policy is necessary for the RPC system to implement embargoes + // properly. virtual kj::Maybe>> whenMoreResolved() = 0; // If this client is a settled reference (not a promise), return nullptr. Otherwise, return a // promise that eventually resolves to a new client that is closer to being the final, settled // client (i.e. the value eventually returned by `getResolved()`). Calling this repeatedly // should eventually produce a settled client. + // + // Once the promise resolves, `getResolved()` must return exactly the same `ClientHook` as the + // one this Promise resolved to. kj::Promise whenResolved(); // Repeatedly calls whenMoreResolved() until it returns nullptr. @@ -751,6 +844,12 @@ class ClientHook { static kj::Own from(Capability::Client client) { return kj::mv(client.hook); } }; +class RevocableClientHook: public ClientHook { +public: + virtual void revoke() = 0; + virtual void revoke(kj::Exception&& reason) = 0; +}; + class CallContextHook { // Hook interface implemented by RPC system to manage a call on the server side. See // CallContext. @@ -760,7 +859,6 @@ class CallContextHook { virtual void releaseParams() = 0; virtual AnyPointer::Builder getResults(kj::Maybe sizeHint) = 0; virtual kj::Promise tailCall(kj::Own&& request) = 0; - virtual void allowCancellation() = 0; virtual void setPipeline(kj::Own&& pipeline) = 0; @@ -774,6 +872,11 @@ class CallContextHook { // promise fulfiller for onTailCall() with the returned pipeline. virtual kj::Own addRef() = 0; + + template + static CallContextHook& from(CallContext& context) { return *context.hook; } + template + static CallContextHook& from(StreamingCallContext& context) { return *context.hook; } }; kj::Own newLocalPromiseClient(kj::Promise>&& promise); @@ -796,6 +899,11 @@ Request newBrokenRequest( kj::Exception&& reason, kj::Maybe sizeHint); // Helper function that creates a Request object that simply throws exceptions when sent. +kj::Own getDisabledPipeline(); +// Gets a PipelineHook appropriate to use when CallHints::noPromisePipelining is true. This will +// throw from all calls. This does not actually allocate the object; a static global object is +// returned with a null disposer. + // ======================================================================================= // Extend PointerHelpers for interfaces @@ -964,6 +1072,13 @@ RemotePromise Request::send() { return RemotePromise(kj::mv(typedPromise), kj::mv(typedPipeline)); } +template +typename Results::Pipeline Request::sendForPipeline() { + auto typelessPipeline = hook->sendForPipeline(); + hook = nullptr; // prevent reuse + return typename Results::Pipeline(kj::mv(typelessPipeline)); +} + template kj::Promise StreamingRequest::send() { auto promise = hook->sendStreaming(); @@ -989,19 +1104,19 @@ inline typename T::Client Capability::Client::castAs() { } inline Request Capability::Client::typelessRequest( uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint) { - return newCall(interfaceId, methodId, sizeHint); + kj::Maybe sizeHint, CallHints hints) { + return newCall(interfaceId, methodId, sizeHint, hints); } template inline Request Capability::Client::newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) { - auto typeless = hook->newCall(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, CallHints hints) { + auto typeless = hook->newCall(interfaceId, methodId, sizeHint, hints); return Request(typeless.template getAs(), kj::mv(typeless.hook)); } template inline StreamingRequest Capability::Client::newStreamingCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) { - auto typeless = hook->newCall(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, CallHints hints) { + auto typeless = hook->newCall(interfaceId, methodId, sizeHint, hints); return StreamingRequest(typeless.template getAs(), kj::mv(typeless.hook)); } @@ -1064,14 +1179,6 @@ inline kj::Promise CallContext::tailCall( Request&& tailRequest) { return hook->tailCall(kj::mv(tailRequest.hook)); } -template -inline void CallContext::allowCancellation() { - hook->allowCancellation(); -} -template -inline void StreamingCallContext::allowCancellation() { - hook->allowCancellation(); -} template CallContext Capability::Server::internalGetTypedContext( @@ -1089,6 +1196,31 @@ Capability::Client Capability::Server::thisCap() { return Client(thisHook->addRef()); } +template +RevocableServer::RevocableServer(typename T::Server& server) + : hook(Capability::Client::makeRevocableLocalClient(server)) {} +template +RevocableServer::~RevocableServer() noexcept(false) { + // Check if moved away. + if (hook.get() != nullptr) { + Capability::Client::revokeLocalClient(*hook); + } +} + +template +typename T::Client RevocableServer::getClient() { + return typename T::Client(hook->addRef()); +} + +template +void RevocableServer::revoke() { + Capability::Client::revokeLocalClient(*hook); +} +template +void RevocableServer::revoke(kj::Exception&& exception) { + Capability::Client::revokeLocalClient(*hook, kj::mv(exception)); +} + namespace _ { // private struct PipelineBuilderPair { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/cc_capnp_library.bzl b/libs/EXTERNAL/capnproto/c++/src/capnp/cc_capnp_library.bzl new file mode 100644 index 00000000000..9e4acd35b9b --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/cc_capnp_library.bzl @@ -0,0 +1,128 @@ +"""Bazel rule to compile .capnp files into c++.""" + +capnp_provider = provider("Capnproto Provider", fields = { + "includes": "includes for this target (transitive)", + "inputs": "src + data for the target", + "src_prefix": "src_prefix of the target", +}) + +def _workspace_path(label, path): + if label.workspace_root == "": + return path + return label.workspace_root + "/" + path + +def _capnp_gen_impl(ctx): + label = ctx.label + src_prefix = _workspace_path(label, ctx.attr.src_prefix) if ctx.attr.src_prefix != "" else "" + includes = [] + + inputs = ctx.files.srcs + ctx.files.data + for dep_target in ctx.attr.deps: + includes += dep_target[capnp_provider].includes + inputs += dep_target[capnp_provider].inputs + + if src_prefix != "": + includes.append(src_prefix) + + system_include = ctx.files._capnp_system[0].dirname.removesuffix("/capnp") + + gen_dir = ctx.var["GENDIR"] + out_dir = gen_dir + if src_prefix != "": + out_dir = out_dir + "/" + src_prefix + + cc_out = "-o%s:%s" % (ctx.executable._capnpc_cxx.path, out_dir) + args = ctx.actions.args() + args.add_all(["compile", "--verbose", cc_out]) + args.add_all(["-I" + inc for inc in includes]) + args.add_all(["-I", system_include]) + + if src_prefix == "": + # guess src_prefix for generated files + for src in ctx.files.srcs: + if src.path.startswith(gen_dir): + src_prefix = gen_dir + break + + if src_prefix != "": + args.add_all(["--src-prefix", src_prefix]) + + args.add_all([s for s in ctx.files.srcs]) + + ctx.actions.run( + inputs = inputs + ctx.files._capnpc_cxx + ctx.files._capnpc_capnp + ctx.files._capnp_system, + outputs = ctx.outputs.outs, + executable = ctx.executable._capnpc, + arguments = [args], + mnemonic = "GenCapnp", + ) + + return [ + capnp_provider( + includes = includes, + inputs = inputs, + src_prefix = src_prefix, + ), + ] + +_capnp_gen = rule( + attrs = { + "srcs": attr.label_list(allow_files = True), + "deps": attr.label_list(providers = [capnp_provider]), + "data": attr.label_list(allow_files = True), + "outs": attr.output_list(), + "src_prefix": attr.string(), + "_capnpc": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnp_tool"), + "_capnpc_cxx": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnpc-c++"), + "_capnpc_capnp": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnpc-capnp"), + "_capnp_system": attr.label(default = "@capnp-cpp//src/capnp:capnp_system_library"), + }, + output_to_genfiles = True, + implementation = _capnp_gen_impl, +) + +def cc_capnp_library( + name, + srcs = [], + data = [], + deps = [], + src_prefix = "", + visibility = None, + target_compatible_with = None, + **kwargs): + """Bazel rule to create a C++ capnproto library from capnp source files + + Args: + name: library name + srcs: list of files to compile + data: additional files to provide to the compiler - data files and includes that need not to + be compiled + deps: other cc_capnp_library rules to depend on + src_prefix: src_prefix for capnp compiler to the source root + visibility: rule visibility + target_compatible_with: target compatibility + **kwargs: rest of the arguments to cc_library rule + """ + + hdrs = [s + ".h" for s in srcs] + srcs_cpp = [s + ".c++" for s in srcs] + + _capnp_gen( + name = name + "_gen", + srcs = srcs, + deps = [s + "_gen" for s in deps], + data = data, + outs = hdrs + srcs_cpp, + src_prefix = src_prefix, + visibility = visibility, + target_compatible_with = target_compatible_with, + ) + native.cc_library( + name = name, + srcs = srcs_cpp, + hdrs = hdrs, + deps = deps + ["@capnp-cpp//src/capnp:capnp_runtime"], + visibility = visibility, + target_compatible_with = target_compatible_with, + **kwargs + ) diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/common.h b/libs/EXTERNAL/capnproto/c++/src/capnp/common.h index aece4e51808..77f0cb58ef8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/common.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/common.h @@ -46,9 +46,9 @@ CAPNP_BEGIN_HEADER namespace capnp { -#define CAPNP_VERSION_MAJOR 0 -#define CAPNP_VERSION_MINOR 9 -#define CAPNP_VERSION_MICRO 1 +#define CAPNP_VERSION_MAJOR 1 +#define CAPNP_VERSION_MINOR 0 +#define CAPNP_VERSION_MICRO 2 #define CAPNP_VERSION \ (CAPNP_VERSION_MAJOR * 1000000 + CAPNP_VERSION_MINOR * 1000 + CAPNP_VERSION_MICRO) @@ -361,7 +361,7 @@ class word { // the copy constructor. We don't want to disable the warning because it's a useful warning and // we'd have to disable it for all applications that include this header. Instead we allow `word` // to be copyable on GCC. - KJ_DISALLOW_COPY(word); + KJ_DISALLOW_COPY_AND_MOVE(word); #endif }; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/BUILD.bazel b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/BUILD.bazel new file mode 100644 index 00000000000..fcaecfa3fa1 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/BUILD.bazel @@ -0,0 +1,98 @@ +load("@capnp-cpp//src/capnp:cc_capnp_library.bzl", "cc_capnp_library") + +exports_files([ + "json.capnp", +]) + +# because git contains generated artifacts (which are used to bootstrap the compiler) +# we can't have cc_capnp_library for json.capnp. Expose it as cc library and a file. +cc_library( + name = "json", + srcs = [ + "json.c++", + "json.capnp.c++", + ], + hdrs = [ + "json.capnp.h", + "json.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/capnp", + ], +) + +cc_capnp_library( + name = "http-over-capnp_capnp", + srcs = [ + "byte-stream.capnp", + "http-over-capnp.capnp", + ], + include_prefix = "capnp/compat", + src_prefix = "src", + visibility = ["//visibility:public"], +) + +cc_library( + name = "http-over-capnp", + srcs = [ + "byte-stream.c++", + "http-over-capnp.c++", + ], + hdrs = [ + "byte-stream.h", + "http-over-capnp.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + ":http-over-capnp_capnp", + "//src/kj/compat:kj-http", + ], +) + +cc_library( + name = "websocket-rpc", + srcs = [ + "websocket-rpc.c++", + ], + hdrs = [ + "websocket-rpc.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/capnp", + "//src/kj/compat:kj-http", + ], +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":websocket-rpc", + ":http-over-capnp", + "//src/capnp:capnp-test" + ], +) for f in [ + "byte-stream-test.c++", + "http-over-capnp-test.c++", + "websocket-rpc-test.c++", +]] + +cc_library( + name = "http-over-capnp-test-as-header", + hdrs = ["http-over-capnp-test.c++"], +) + +cc_test( + name = "http-over-capnp-old-test", + srcs = ["http-over-capnp-old-test.c++"], + deps = [ + ":http-over-capnp-test-as-header", + ":http-over-capnp", + "//src/capnp:capnp-test" + ], +) diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream-test.c++ index 297165a63cf..49fede1fec0 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream-test.c++ @@ -33,7 +33,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -44,7 +44,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { } return expectRead(in, expected.slice(amount)); - })); + }); } kj::String makeString(size_t size) { @@ -307,7 +307,7 @@ KJ_TEST("KJ -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with shortening rpc::twoparty::Side::CLIENT); capnp::TwoPartyClient server(*rpcConnection.ends[1], serverFactory.kjToCapnp(kj::mv(middlePipe.out)), - rpc::twoparty::Side::CLIENT); + rpc::twoparty::Side::SERVER); auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped, 3); @@ -377,7 +377,7 @@ KJ_TEST("KJ -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with concurrent rpc::twoparty::Side::CLIENT); capnp::TwoPartyClient server(*rpcConnection.ends[1], serverFactory.kjToCapnp(kj::mv(middlePipe.out)), - rpc::twoparty::Side::CLIENT); + rpc::twoparty::Side::SERVER); auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped); @@ -448,7 +448,7 @@ KJ_TEST("KJ -> KJ pipe -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with rpc::twoparty::Side::CLIENT); capnp::TwoPartyClient server(*rpcConnection.ends[1], serverFactory.kjToCapnp(kj::mv(middlePipe.out)), - rpc::twoparty::Side::CLIENT); + rpc::twoparty::Side::SERVER); auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.c++ index c7b709dc54f..2602b4a9c52 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.c++ @@ -59,6 +59,9 @@ public: // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate // that by destroying our underlying stream. // TODO(cleanup): When KJ streams evolve an end() method, this can go away. + + virtual kj::Promise directExplicitEnd() = 0; + // Like directEnd() but used in cases where an explicit end() call actually was made. }; class ByteStreamFactory::SubstreamImpl final: public StreamServerBase { @@ -69,9 +72,10 @@ public: kj::AsyncOutputStream& stream, capnp::ByteStream::SubstreamCallback::Client callback, uint64_t limit, + kj::Maybe> tlsStarter, kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller()) : factory(factory), - state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback)}), + state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback), kj::mv(tlsStarter)}), limit(limit), resolveFulfiller(kj::mv(paf.fulfiller)), resolvePromise(paf.promise.fork()) {} @@ -133,6 +137,32 @@ public: } } + kj::Promise directExplicitEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + // Ugh I guess we need to send a real end() request here. + return redirected.replacement.endRequest(MessageSize {2, 0}).send().ignoreResult(); + } + KJ_CASE_ONEOF(e, Ended) { + // whatever + return kj::READY_NOW; + } + KJ_CASE_ONEOF(b, Borrowed) { + // ... whatever. + return kj::READY_NOW; + } + KJ_CASE_ONEOF(streaming, Streaming) { + auto req = streaming.callback.endedRequest(MessageSize {4, 0}); + req.setByteCount(completed); + auto promise = req.send().ignoreResult(); + streaming.parent.returnStream(completed); + state = Ended(); + return promise; + } + } + KJ_UNREACHABLE; + } + // --------------------------------------------------------------------------- // implements ByteStream::Server RPC interface @@ -200,6 +230,10 @@ public: KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); } KJ_CASE_ONEOF(streaming, Streaming) { + // Revoke the TLS starter when stream is ended. This will ensure any startTls calls + // cannot be falsely invoked after the stream is destroyed. + auto drop = kj::mv(streaming.tlsStarter); + auto req = streaming.callback.endedRequest(MessageSize {4, 0}); req.setByteCount(completed); auto result = req.send().ignoreResult(); @@ -211,6 +245,10 @@ public: KJ_UNREACHABLE; } + kj::Promise startTls(StartTlsContext context) override { + KJ_UNIMPLEMENTED("A substream does not support TLS initiation"); + } + kj::Promise getSubstream(GetSubstreamContext context) override { KJ_SWITCH_ONEOF(state) { KJ_CASE_ONEOF(redirected, Redirected) { @@ -233,7 +271,8 @@ public: context.releaseParams(); auto results = context.getResults(MessageSize { 2, 1 }); results.setSubstream(factory.streamSet.add(kj::heap( - factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit)))); + factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit), + kj::mv(streaming.tlsStarter)))); state = Borrowed { kj::mv(streaming) }; return kj::READY_NOW; } @@ -249,6 +288,7 @@ private: capnp::ByteStream::Client ownParent; kj::AsyncOutputStream& stream; capnp::ByteStream::SubstreamCallback::Client callback; + kj::Maybe> tlsStarter; }; struct Borrowed { Streaming originalState; @@ -295,6 +335,15 @@ public: state.get>()->startProbing(); } + CapnpToKjStreamAdapter(ByteStreamFactory& factory, + kj::Own inner, + kj::Maybe> starter) + : factory(factory), + tlsStarter(kj::mv(starter)), + state(kj::heap(*this, kj::mv(inner))) { + state.get>()->startProbing(); + } + CapnpToKjStreamAdapter(ByteStreamFactory& factory, kj::Own pathProber) : factory(factory), @@ -360,6 +409,32 @@ public: } } + kj::Promise directExplicitEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + state = Ended(); + return kj::READY_NOW; + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + state = Ended(); + return kj::READY_NOW; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + // Ugh I guess we need to send a real end() request here. + return capnpStream.endRequest(MessageSize {2, 0}).send().ignoreResult(); + } + KJ_CASE_ONEOF(b, Borrowed) { + // Fine, ignore. + return kj::READY_NOW; + } + KJ_CASE_ONEOF(e, Ended) { + // Fine, ignore. + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + // --------------------------------------------------------------------------- // PathProber @@ -418,7 +493,7 @@ public: // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was // eventually called, meaning the substream was ended without redirecting back to us. So, // we're at EOF. - return uint64_t(0); + return kj::constPromise(); } } @@ -430,7 +505,7 @@ public: // works because pumps do not propagate EOF -- the destination can still receive further // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having // each write() RPC directly call write() on the inner stream. - return size_t(0); + return kj::constPromise(); } kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { @@ -442,7 +517,7 @@ public: return kj::mv(*promise); } else { // There is no shorter path. As with tryRead(), we pretend we get immediate EOF. - return uint64_t(0); + return kj::constPromise(); } } @@ -552,6 +627,10 @@ protected: } kj::Promise end(EndContext context) override { + // Revoke the TLS starter when stream is ended. This will ensure any startTls calls + // cannot be falsely invoked after the stream is destroyed. + auto drop = kj::mv(tlsStarter); + KJ_SWITCH_ONEOF(state) { KJ_CASE_ONEOF(prober, kj::Own) { return prober->whenReady().then([this, context]() mutable { @@ -582,6 +661,30 @@ protected: KJ_UNREACHABLE; } + kj::Promise startTls(StartTlsContext context) override { + auto params = context.getParams(); + KJ_IF_MAYBE(s, tlsStarter) { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("cannot call startTls on a bytestream that was ended"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call startTls while stream is borrowed"); + } + } + } + KJ_UNREACHABLE; + } + kj::Promise getSubstream(GetSubstreamContext context) override { KJ_SWITCH_ONEOF(state) { KJ_CASE_ONEOF(prober, kj::Own) { @@ -598,7 +701,8 @@ protected: auto results = context.initResults(MessageSize {2, 1}); results.setSubstream(factory.streamSet.add(kj::heap( - factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit)))); + factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit), + kj::mv(tlsStarter)))); state = Borrowed { kj::mv(kjStream) }; return kj::READY_NOW; } @@ -623,6 +727,7 @@ protected: private: ByteStreamFactory& factory; + kj::Maybe> tlsStarter; struct Borrowed { kj::Own stream; }; struct Ended {}; @@ -690,22 +795,36 @@ private: // ======================================================================================= -class ByteStreamFactory::KjToCapnpStreamAdapter final: public kj::AsyncOutputStream { +class ByteStreamFactory::KjToCapnpStreamAdapter final: public ExplicitEndOutputStream { public: - KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam) + KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam, + bool explicitEnd) : factory(factory), inner(kj::mv(innerParam)), - findShorterPathTask(findShorterPath(inner).fork()) {} + findShorterPathTask(findShorterPath(inner).fork()), + explicitEnd(explicitEnd) {} ~KjToCapnpStreamAdapter() noexcept(false) { - // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We - // use a detached promise for now, which is probably OK since capabilities are refcounted and - // asynchronously destroyed anyway. - // TODO(cleanup): Fix this when KJ streads add an explicit end() method. + if (!explicitEnd) { + // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We + // use a detached promise for now, which is probably OK since capabilities are refcounted and + // asynchronously destroyed anyway. + // TODO(cleanup): Fix this when KJ streads add an explicit end() method. + KJ_IF_MAYBE(o, optimized) { + o->directEnd(); + } else { + inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); + } + } + } + + kj::Promise end() override { + KJ_REQUIRE(explicitEnd, "not expecting explicit end"); + KJ_IF_MAYBE(o, optimized) { - o->directEnd(); + return o->directExplicitEnd(); } else { - inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); + return inner.endRequest(MessageSize {2, 0}).send().ignoreResult(); } } @@ -832,6 +951,9 @@ private: // possible. // 2. Implements whenWriteDisconnected(). + bool explicitEnd; + // Did the creator promise to explicitly call end()? + kj::Promise findShorterPath(capnp::ByteStream::Client& capnpClient) { // If the capnp stream turns out to resolve back to this process, shorten the path. // Also, implement whenWriteDisconnected() based on this. @@ -938,7 +1060,7 @@ private: } } KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { - // Pumping from some other kind of steram. Optimize the pump by reading from the input + // Pumping from some other kind of stream. Optimize the pump by reading from the input // directly into outgoing RPC messages. size_t size = kj::min(remaining, 8192); auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word) }); @@ -1022,8 +1144,19 @@ capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own(*this, kj::mv(kjStream))); } +capnp::ByteStream::Client ByteStreamFactory::kjToCapnp( + kj::Own kjStream, kj::Maybe> tlsStarter) { + return streamSet.add( + kj::heap(*this, kj::mv(kjStream), kj::mv(tlsStarter))); +} + kj::Own ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) { - return kj::heap(*this, kj::mv(capnpStream)); + return kj::heap(*this, kj::mv(capnpStream), false); +} + +kj::Own ByteStreamFactory::capnpToKjExplicitEnd( + capnp::ByteStream::Client capnpStream) { + return kj::heap(*this, kj::mv(capnpStream), true); } } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.capnp index b98d85e9fb2..3298c06f485 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.capnp @@ -1,6 +1,8 @@ @0x8f5d14e1c273738d; -$import "/capnp/c++.capnp".namespace("capnp"); +using Cxx = import "/capnp/c++.capnp"; +$Cxx.namespace("capnp"); +$Cxx.allowCancellation; interface ByteStream { write @0 (bytes :Data) -> stream; @@ -23,6 +25,10 @@ interface ByteStream { # While a substream is active, it is an error to call write() on the original stream. Doing so # may throw an exception or may arbitrarily interleave bytes with the substream's writes. + startTls @3 (expectedServerHostname :Text) -> stream; + # Client calls this method when it wants to initiate TLS. This ByteStream is not terminated, + # the caller should reuse it. + interface SubstreamCallback { ended @0 (byteCount :UInt64); # `end()` was called on the substream after writing `byteCount` bytes. The `end()` call was diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.h index 545e6e592b7..b34aa3f521a 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/byte-stream.h @@ -24,17 +24,43 @@ #include #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { +class ExplicitEndOutputStream: public kj::AsyncOutputStream { + // HACK: KJ's AsyncOutputStream has a known serious design flaw in that EOF is signaled by + // destroying the stream object rather than by calling an explicit `end()` method. This causes + // some serious problems when signaling EOF requires doing additional I/O, such as when + // wrapping a capnp ByteStream where `end()` is an RPC call. + // + // When it really must, the ByteStream implementation will honor the KJ convention by starting + // the RPC in its destructor and detach()ing the promise. But, this has lots of negative side + // effects, especially in the case where the stream is really meant to be aborted abruptly. + // + // In lieu of an actual deep refactoring of KJ, ByteStreamFactory allows its caller to + // explicily specify when it is able to promise that it will make an explicit `end()` call. + // capnpToKjExplicitEnd() returns an ExplicitEndOutputStream, which expect to receive an + // `end()` call on clean EOF, and treats destruction without `end()` as an abort. This is used + // in particular within http-over-capnp to improve behavior somewhat. +public: + virtual kj::Promise end() = 0; +}; + class ByteStreamFactory { // In order to allow path-shortening through KJ, a common factory must be used for converting // between RPC ByteStreams and KJ streams. public: capnp::ByteStream::Client kjToCapnp(kj::Own kjStream); + capnp::ByteStream::Client kjToCapnp( + kj::Own kjStream, kj::Maybe> tlsStarter); kj::Own capnpToKj(capnp::ByteStream::Client capnpStream); + kj::Own capnpToKjExplicitEnd(capnp::ByteStream::Client capnpStream); + private: CapabilityServerSet streamSet; @@ -45,3 +71,5 @@ class ByteStreamFactory { }; } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-old-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-old-test.c++ new file mode 100644 index 00000000000..9a5aea9b135 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-old-test.c++ @@ -0,0 +1,2 @@ +#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_1 +#include "http-over-capnp-test.c++" diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-perf-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-perf-test.c++ new file mode 100644 index 00000000000..20ba63840b6 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-perf-test.c++ @@ -0,0 +1,446 @@ +// Copyright (c) 2022 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "http-over-capnp.h" +#include +#include +#include +#include +#include +#if KJ_BENCHMARK_MALLOC +#include +#endif + +#if KJ_HAS_COROUTINE + +namespace capnp { +namespace { + +// ======================================================================================= +// Metrics-gathering +// +// TODO(cleanup): Generalize for other benchmarks? + +static size_t globalMallocCount = 0; +static size_t globalMallocBytes = 0; + +#if KJ_BENCHMARK_MALLOC +// If KJ_BENCHMARK_MALLOC is defined then we are instructed to override malloc() in order to +// measure total allocations. We are careful only to define this when the build is set up in a +// way that this won't cause build failures (e.g., we must not be statically linking a malloc +// implementation). + +extern "C" { + +void* malloc(size_t size) { + typedef void* Malloc(size_t); + static Malloc* realMalloc = reinterpret_cast(dlsym(RTLD_NEXT, "malloc")); + + ++globalMallocCount; + globalMallocBytes += size; + return realMalloc(size); +} + +} // extern "C" + +#endif // KJ_BENCHMARK_MALLOC + +class Metrics { +public: + Metrics() + : startMallocCount(globalMallocCount), startMallocBytes(globalMallocBytes), + upBandwidth(0), downBandwidth(0), + clientReadCount(0), clientWriteCount(0), + serverReadCount(0), serverWriteCount(0) {} + ~Metrics() noexcept(false) { + #if KJ_BENCHMARK_MALLOC + size_t mallocCount = globalMallocCount - startMallocCount; + size_t mallocBytes = globalMallocBytes - startMallocBytes; + KJ_LOG(WARNING, mallocCount, mallocBytes); + #endif + + if (hadStreamPair) { + KJ_LOG(WARNING, upBandwidth, downBandwidth, + clientReadCount, clientWriteCount, serverReadCount, serverWriteCount); + } + } + + enum Side { CLIENT, SERVER }; + + class StreamWrapper final: public kj::AsyncIoStream { + // Wrap a stream and count metrics. + + public: + StreamWrapper(Metrics& metrics, kj::AsyncIoStream& inner, Side side) + : metrics(metrics), inner(inner), side(side) {} + + ~StreamWrapper() noexcept(false) { + switch (side) { + case CLIENT: + metrics.clientReadCount += readCount; + metrics.clientWriteCount += writeCount; + metrics.upBandwidth += writeBytes; + metrics.downBandwidth += readBytes; + break; + case SERVER: + metrics.serverReadCount += readCount; + metrics.serverWriteCount += writeCount; + break; + } + } + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.read(buffer, minBytes, maxBytes) + .then([this](size_t n) { + ++readCount; + readBytes += n; + return n; + }); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.tryRead(buffer, minBytes, maxBytes) + .then([this](size_t n) { + ++readCount; + readBytes += n; + return n; + }); + } + + kj::Maybe tryGetLength() override { + return inner.tryGetLength(); + } + + kj::Promise write(const void* buffer, size_t size) override { + ++writeCount; + writeBytes += size; + return inner.write(buffer, size); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + ++writeCount; + for (auto& piece: pieces) { + writeBytes += piece.size(); + } + return inner.write(pieces); + } + + kj::Promise pumpTo( + kj::AsyncOutputStream& output, uint64_t amount = kj::maxValue) override { + // Our benchmarks don't depend on this currently. If they do we need to think about how to + // apply it. + KJ_UNIMPLEMENTED("pump metrics"); + } + kj::Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + // Our benchmarks don't depend on this currently. If they do we need to think about how to + // apply it. + KJ_UNIMPLEMENTED("pump metrics"); + } + + kj::Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + + void shutdownWrite() override { + inner.shutdownWrite(); + } + void abortRead() override { + inner.abortRead(); + } + + private: + Metrics& metrics; + kj::AsyncIoStream& inner; + Side side; + + size_t readCount = 0; + size_t readBytes = 0; + size_t writeCount = 0; + size_t writeBytes = 0; + }; + + struct StreamPair { + kj::TwoWayPipe pipe; + StreamWrapper client; + StreamWrapper server; + + StreamPair(Metrics& metrics) + : pipe(kj::newTwoWayPipe()), + client(metrics, *pipe.ends[0], CLIENT), + server(metrics, *pipe.ends[1], SERVER) { + metrics.hadStreamPair = true; + } + }; + +private: + size_t startMallocCount KJ_UNUSED; + size_t startMallocBytes KJ_UNUSED; + size_t upBandwidth; + size_t downBandwidth; + size_t clientReadCount; + size_t clientWriteCount; + size_t serverReadCount; + size_t serverWriteCount; + + bool hadStreamPair = false; +}; + +// ======================================================================================= + +static constexpr auto HELLO_WORLD = "Hello, world!"_kj; + +class NullInputStream final: public kj::AsyncInputStream { +public: + NullInputStream(kj::Maybe expectedLength = size_t(0)) + : expectedLength(expectedLength) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return size_t(0); + } + + kj::Maybe tryGetLength() override { + return expectedLength; + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return uint64_t(0); + } + +private: + kj::Maybe expectedLength; +}; + +class VectorOutputStream: public kj::AsyncOutputStream { +public: + kj::String consume() { + chars.add('\0'); + return kj::String(chars.releaseAsArray()); + } + + kj::Promise write(const void* buffer, size_t size) override { + chars.addAll(kj::arrayPtr(reinterpret_cast(buffer), size)); + return kj::READY_NOW; + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + for (auto piece: pieces) { + chars.addAll(piece.asChars()); + } + return kj::READY_NOW; + } + + kj::Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + +private: + kj::Vector chars; +}; + +class MockService: public kj::HttpService { +public: + MockService(kj::HttpHeaderTable::Builder& headerTableBuilder) + : headerTable(headerTableBuilder.getFutureTable()), + customHeaderId(headerTableBuilder.add("X-Custom-Header")) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_ASSERT(method == kj::HttpMethod::GET); + KJ_ASSERT(url == "http://foo"_kj); + KJ_ASSERT(headers.get(customHeaderId) == "corge"_kj); + + kj::HttpHeaders responseHeaders(headerTable); + responseHeaders.set(kj::HttpHeaderId::CONTENT_TYPE, "text/plain"); + responseHeaders.set(customHeaderId, "foobar"_kj); + auto stream = response.send(200, "OK", responseHeaders); + auto promise = stream->write(HELLO_WORLD.begin(), HELLO_WORLD.size()); + return promise.attach(kj::mv(stream)); + } + +private: + const kj::HttpHeaderTable& headerTable; + kj::HttpHeaderId customHeaderId; +}; + +class MockSender: private kj::HttpService::Response { +public: + MockSender(kj::HttpHeaderTable::Builder& headerTableBuilder) + : headerTable(headerTableBuilder.getFutureTable()), + customHeaderId(headerTableBuilder.add("X-Custom-Header")) {} + + kj::Promise sendRequest(kj::HttpClient& client) { + kj::HttpHeaders headers(headerTable); + headers.set(customHeaderId, "corge"_kj); + auto req = client.request(kj::HttpMethod::GET, "http://foo"_kj, headers); + req.body = nullptr; + auto resp = co_await req.response; + KJ_ASSERT(resp.statusCode == 200); + KJ_ASSERT(resp.statusText == "OK"_kj); + KJ_ASSERT(resp.headers->get(customHeaderId) == "foobar"_kj); + + auto body = co_await resp.body->readAllText(); + KJ_ASSERT(body == HELLO_WORLD); + } + + kj::Promise sendRequest(kj::HttpService& service) { + kj::HttpHeaders headers(headerTable); + headers.set(customHeaderId, "corge"_kj); + NullInputStream requestBody; + co_await service.request(kj::HttpMethod::GET, "http://foo"_kj, headers, requestBody, *this); + KJ_ASSERT(responseBody.consume() == HELLO_WORLD); + } + +private: + const kj::HttpHeaderTable& headerTable; + kj::HttpHeaderId customHeaderId; + + VectorOutputStream responseBody; + + kj::Own send( + uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_ASSERT(statusCode == 200); + KJ_ASSERT(statusText == "OK"_kj); + KJ_ASSERT(headers.get(customHeaderId) == "foobar"_kj); + + return kj::attachRef(responseBody); + } + + kj::Own acceptWebSocket(const kj::HttpHeaders& headers) override { + KJ_UNIMPLEMENTED("no WebSockets here"); + } +}; + +KJ_TEST("Benchmark baseline") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + doBenchmark([&]() { + sender.sendRequest(service).wait(waitScope); + }); +} + +KJ_TEST("Benchmark KJ HTTP client wrapper") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + auto client = kj::newHttpClient(service); + + doBenchmark([&]() { + sender.sendRequest(*client).wait(waitScope); + }); +} + +KJ_TEST("Benchmark KJ HTTP full protocol") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + Metrics::StreamPair pair(metrics); + kj::TimerImpl timer(kj::origin()); + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + kj::HttpServer server(timer, *headerTable, service); + auto listenLoop = server.listenHttp({&pair.server, kj::NullDisposer::instance}) + .eagerlyEvaluate([](kj::Exception&& e) noexcept { kj::throwFatalException(kj::mv(e)); }); + auto client = kj::newHttpClient(*headerTable, pair.client); + + doBenchmark([&]() { + sender.sendRequest(*client).wait(waitScope); + }); +} + +KJ_TEST("Benchmark HTTP-over-capnp local call") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + HttpOverCapnpFactory::HeaderIdBundle headerIds(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + // Client and server use different HttpOverCapnpFactory instances to block path-shortening. + ByteStreamFactory bsFactory; + HttpOverCapnpFactory hocFactory(bsFactory, headerIds.clone(), HttpOverCapnpFactory::LEVEL_2); + ByteStreamFactory bsFactory2; + HttpOverCapnpFactory hocFactory2(bsFactory2, kj::mv(headerIds), HttpOverCapnpFactory::LEVEL_2); + + auto cap = hocFactory.kjToCapnp(kj::attachRef(service)); + auto roundTrip = hocFactory2.capnpToKj(kj::mv(cap)); + + doBenchmark([&]() { + sender.sendRequest(*roundTrip).wait(waitScope); + }); +} + +KJ_TEST("Benchmark HTTP-over-capnp full RPC") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + Metrics::StreamPair pair(metrics); + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + HttpOverCapnpFactory::HeaderIdBundle headerIds(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + // Client and server use different HttpOverCapnpFactory instances to block path-shortening. + ByteStreamFactory bsFactory; + HttpOverCapnpFactory hocFactory(bsFactory, headerIds.clone(), HttpOverCapnpFactory::LEVEL_2); + ByteStreamFactory bsFactory2; + HttpOverCapnpFactory hocFactory2(bsFactory2, kj::mv(headerIds), HttpOverCapnpFactory::LEVEL_2); + + TwoPartyServer server(hocFactory.kjToCapnp(kj::attachRef(service))); + + auto pipe = kj::newTwoWayPipe(); + auto listenLoop = server.accept(pair.server); + + TwoPartyClient client(pair.client); + + auto roundTrip = hocFactory2.capnpToKj(client.bootstrap().castAs()); + + doBenchmark([&]() { + sender.sendRequest(*roundTrip).wait(waitScope); + }); +} + +} // namespace +} // namespace capnp + +#endif // KJ_HAS_COROUTINE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-test.c++ index 771e25a9e0b..425014cabbb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp-test.c++ @@ -22,6 +22,10 @@ #include "http-over-capnp.h" #include +#ifndef TEST_PEER_OPTIMIZATION_LEVEL +#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_2 +#endif + namespace capnp { namespace { @@ -42,7 +46,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -53,7 +57,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { } return expectRead(in, expected.slice(amount)); - })); + }); } enum Direction { @@ -402,8 +406,8 @@ KJ_TEST("HTTP-over-Cap'n-Proto E2E, no path shortening") { ByteStreamFactory streamFactory1; ByteStreamFactory streamFactory2; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory1(streamFactory1, tableBuilder); - HttpOverCapnpFactory factory2(streamFactory2, tableBuilder); + HttpOverCapnpFactory factory1(streamFactory1, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + HttpOverCapnpFactory factory2(streamFactory2, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); runEndToEndTests(timer, *headerTable, factory1, factory2, waitScope); @@ -416,7 +420,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto E2E, with path shortening") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); runEndToEndTests(timer, *headerTable, factory, factory, waitScope); @@ -438,7 +442,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto 205 bug with HttpClientAdapter") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); auto pipe = kj::newTwoWayPipe(); @@ -511,45 +515,29 @@ private: kj::Promise done; }; -void runWebSocketTests(kj::HttpHeaderTable& headerTable, - HttpOverCapnpFactory& clientFactory, HttpOverCapnpFactory& serverFactory, - kj::WaitScope& waitScope) { - // We take a different approach here, because writing out raw WebSocket frames is a pain. - // It's easier to test WebSockets at the KJ API level. - - auto wsPaf = kj::newPromiseAndFulfiller>(); - auto donePaf = kj::newPromiseAndFulfiller(); - - auto back = serverFactory.kjToCapnp(kj::heap( - headerTable, kj::mv(wsPaf.fulfiller), kj::mv(donePaf.promise))); - auto front = clientFactory.capnpToKj(back); - auto client = kj::newHttpClient(*front); - - auto resp = client->openWebSocket("/ws", kj::HttpHeaders(headerTable)).wait(waitScope); - KJ_ASSERT(resp.webSocketOrBody.is>()); - - auto clientWs = kj::mv(resp.webSocketOrBody.get>()); - auto serverWs = wsPaf.promise.wait(waitScope); +void runWebSocketBasicTestCase( + kj::WebSocket& clientWs, kj::WebSocket& serverWs, kj::WaitScope& waitScope) { + // Called by `runWebSocketTests()`. { - auto promise = clientWs->send("foo"_kj); - auto message = serverWs->receive().wait(waitScope); + auto promise = clientWs.send("foo"_kj); + auto message = serverWs.receive().wait(waitScope); promise.wait(waitScope); KJ_ASSERT(message.is()); KJ_EXPECT(message.get() == "foo"); } { - auto promise = serverWs->send("bar"_kj.asBytes()); - auto message = clientWs->receive().wait(waitScope); + auto promise = serverWs.send("bar"_kj.asBytes()); + auto message = clientWs.receive().wait(waitScope); promise.wait(waitScope); KJ_ASSERT(message.is>()); KJ_EXPECT(kj::str(message.get>().asChars()) == "bar"); } { - auto promise = clientWs->close(1234, "baz"_kj); - auto message = serverWs->receive().wait(waitScope); + auto promise = clientWs.close(1234, "baz"_kj); + auto message = serverWs.receive().wait(waitScope); promise.wait(waitScope); KJ_ASSERT(message.is()); KJ_EXPECT(message.get().code == 1234); @@ -557,14 +545,53 @@ void runWebSocketTests(kj::HttpHeaderTable& headerTable, } { - auto promise = serverWs->disconnect(); - auto receivePromise = clientWs->receive(); + auto promise = serverWs.disconnect(); + auto receivePromise = clientWs.receive(); KJ_EXPECT(receivePromise.poll(waitScope)); KJ_EXPECT_THROW(DISCONNECTED, receivePromise.wait(waitScope)); promise.wait(waitScope); } } +void runWebSocketAbortTestCase( + kj::WebSocket& clientWs, kj::WebSocket& serverWs, kj::WaitScope& waitScope) { + auto onAbort = serverWs.whenAborted(); + KJ_EXPECT(!onAbort.poll(waitScope)); + clientWs.abort(); + + // At one time, this promise hung forever. + KJ_EXPECT(onAbort.poll(waitScope)); + onAbort.wait(waitScope); +} + +void runWebSocketTests(kj::HttpHeaderTable& headerTable, + HttpOverCapnpFactory& clientFactory, HttpOverCapnpFactory& serverFactory, + kj::WaitScope& waitScope) { + // We take a different approach here, because writing out raw WebSocket frames is a pain. + // It's easier to test WebSockets at the KJ API level. + + for (auto testCase: { + runWebSocketBasicTestCase, + runWebSocketAbortTestCase, + }) { + auto wsPaf = kj::newPromiseAndFulfiller>(); + auto donePaf = kj::newPromiseAndFulfiller(); + + auto back = serverFactory.kjToCapnp(kj::heap( + headerTable, kj::mv(wsPaf.fulfiller), kj::mv(donePaf.promise))); + auto front = clientFactory.capnpToKj(back); + auto client = kj::newHttpClient(*front); + + auto resp = client->openWebSocket("/ws", kj::HttpHeaders(headerTable)).wait(waitScope); + KJ_ASSERT(resp.webSocketOrBody.is>()); + + auto clientWs = kj::mv(resp.webSocketOrBody.get>()); + auto serverWs = wsPaf.promise.wait(waitScope); + + testCase(*clientWs, *serverWs, waitScope); + } +} + KJ_TEST("HTTP-over-Cap'n Proto WebSocket, no path shortening") { kj::EventLoop eventLoop; kj::WaitScope waitScope(eventLoop); @@ -572,8 +599,8 @@ KJ_TEST("HTTP-over-Cap'n Proto WebSocket, no path shortening") { ByteStreamFactory streamFactory1; ByteStreamFactory streamFactory2; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory1(streamFactory1, tableBuilder); - HttpOverCapnpFactory factory2(streamFactory2, tableBuilder); + HttpOverCapnpFactory factory1(streamFactory1, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + HttpOverCapnpFactory factory2(streamFactory2, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); runWebSocketTests(*headerTable, factory1, factory2, waitScope); @@ -585,7 +612,7 @@ KJ_TEST("HTTP-over-Cap'n Proto WebSocket, with path shortening") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); runWebSocketTests(*headerTable, factory, factory, waitScope); @@ -620,7 +647,7 @@ KJ_TEST("HttpService isn't destroyed while call outstanding") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); auto headerTable = tableBuilder.build(); bool called = false; @@ -644,5 +671,319 @@ KJ_TEST("HttpService isn't destroyed while call outstanding") { KJ_EXPECT(!destroyed); } + +class ConnectWriteCloseService final: public kj::HttpService { + // A simple CONNECT server that will accept a connection, write some data and close the + // connection. +public: + ConnectWriteCloseService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", kj::HttpHeaders(headerTable)); + return io.write("test", 4).then([&io]() mutable { + io.shutdownWrite(); + }); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +class ConnectWriteRespService final: public kj::HttpService { +public: + ConnectWriteRespService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", kj::HttpHeaders(headerTable)); + // TODO(later): `io.pumpTo(io).ignoreResult;` doesn't work here, + // it causes startTls to come back in a loop. The below avoids this. + auto buffer = kj::heapArray(4096); + return manualPumpLoop(buffer, io).attach(kj::mv(buffer)); + } + + kj::Promise manualPumpLoop(kj::ArrayPtr buffer, kj::AsyncIoStream& io) { + return io.tryRead(buffer.begin(), 1, buffer.size()).then( + [this,&io,buffer](size_t amount) mutable -> kj::Promise { + if (amount == 0) { return kj::READY_NOW; } + return io.write(buffer.begin(), amount).then([this,&io,buffer]() mutable -> kj::Promise { + return manualPumpLoop(buffer, io); + }); + }); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +class ConnectRejectService final: public kj::HttpService { + // A simple CONNECT server that will reject a connection. +public: + ConnectRejectService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + auto body = response.reject(500, "Internal Server Error", kj::HttpHeaders(headerTable), 5); + return body->write("Error", 5).attach(kj::mv(body)); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +KJ_TEST("HTTP-over-Cap'n-Proto Connect with close") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectWriteCloseService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*client)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + struct ResponseImpl final: public kj::HttpService::ConnectResponse { + kj::Own> fulfiller; + ResponseImpl(kj::Own> fulfiller) + : fulfiller(kj::mv(fulfiller)) {} + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + fulfiller->fulfill( + kj::HttpClient::ConnectRequest::Status( + statusCode, + kj::str(statusText), + kj::heap(headers.clone()), + nullptr + ) + ); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_UNREACHABLE; + } + }; + + auto clientPipe = kj::newTwoWayPipe(); + auto paf = kj::newPromiseAndFulfiller(); + ResponseImpl response(kj::mv(paf.fulfiller)); + + auto promise = frontCapnpHttpService->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), *clientPipe.ends[0], + response, {}).attach(kj::mv(clientPipe.ends[0])); + + paf.promise.then( + [io = kj::mv(clientPipe.ends[1])](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto buf = kj::heapArray(4); + return io->tryRead(buf.begin(), 4, 4).then( + [buf = kj::mv(buf), io = kj::mv(io)](size_t count) mutable { + KJ_ASSERT(count == 4, "Expecting the stream to read 4 chars."); + return io->tryRead(buf.begin(), 1, 1).then( + [buf = kj::mv(buf)](size_t count) mutable { + KJ_ASSERT(count == 0, "Expecting the stream to get disconnected."); + }).attach(kj::mv(io)); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); +} + + +KJ_TEST("HTTP-over-Cap'n-Proto Connect Reject") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectRejectService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*client)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + struct ResponseImpl final: public kj::HttpService::ConnectResponse { + kj::Own>> fulfiller; + ResponseImpl(kj::Own>> fulfiller) + : fulfiller(kj::mv(fulfiller)) {} + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_UNREACHABLE; + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_ASSERT(statusCode == 500); + KJ_ASSERT(statusText == "Internal Server Error"); + KJ_ASSERT(expectedBodySize.orDefault(5)); + auto pipe = kj::newOneWayPipe(); + fulfiller->fulfill(kj::mv(pipe.in)); + return kj::mv(pipe.out); + } + }; + + auto clientPipe = kj::newTwoWayPipe(); + auto paf = kj::newPromiseAndFulfiller>(); + ResponseImpl response(kj::mv(paf.fulfiller)); + + auto promise = frontCapnpHttpService->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), *clientPipe.ends[0], + response, {}).attach(kj::mv(clientPipe.ends[0])); + + paf.promise.then( + [](auto body) mutable { + auto buf = kj::heapArray(5); + return body->tryRead(buf.begin(), 5, 5).then( + [buf = kj::mv(buf), body = kj::mv(body)](size_t count) mutable { + KJ_ASSERT(count == 5, "Expecting the stream to read 5 chars."); + }); + }).attach(kj::mv(promise)).wait(waitScope); + + listenTask.wait(waitScope); +} + +kj::Promise expectEnd(kj::AsyncInputStream& in) { + static char buffer; + + auto promise = in.tryRead(&buffer, 1, 1); + return promise.then([](size_t amount) { + KJ_ASSERT(amount == 0, "expected EOF"); + }); +} + +KJ_TEST("HTTP-over-Cap'n-Proto Connect with startTls") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectWriteRespService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + class WrapperHttpClient final: public kj::HttpClient { + public: + kj::HttpClient& inner; + + WrapperHttpClient(kj::HttpClient& client) : inner(client) {}; + + kj::Promise openWebSocket( + kj::StringPtr url, const kj::HttpHeaders& headers) override { KJ_UNREACHABLE; } + Request request(kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { KJ_UNREACHABLE; } + + ConnectRequest connect(kj::StringPtr host, const kj::HttpHeaders& headers, + kj::HttpConnectSettings settings) override { + KJ_IF_MAYBE(starter, settings.tlsStarter) { + *starter = [](kj::StringPtr) { + return kj::READY_NOW; + }; + } + + return inner.connect(host, headers, settings); + } + }; + + // Only need this wrapper to define a dummy tlsStarter. + auto wrappedClient = kj::heap(*client); + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*wrappedClient)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + auto frontCapnpHttpClient = kj::newHttpClient(*frontCapnpHttpService); + + kj::Own tlsStarter = kj::heap(); + kj::HttpConnectSettings settings = { .useTls = false }; + settings.tlsStarter = tlsStarter; + + auto request = frontCapnpHttpClient->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), settings); + + KJ_ASSERT_NONNULL(*tlsStarter); + + request.status.then( + [io=kj::mv(request.connection), &tlsStarter](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + return KJ_ASSERT_NONNULL(*tlsStarter)("example.com").then([io = kj::mv(io)]() mutable { + return io->write("hello", 5).then([io = kj::mv(io)]() mutable { + auto buffer = kj::heapArray(5); + return io->tryRead(buffer.begin(), 5, 5).then( + [io = kj::mv(io), buffer = kj::mv(buffer)](size_t) mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); +} + } // namespace } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.c++ index f76e309f69c..a92e7c59d1a 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.c++ @@ -22,6 +22,7 @@ #include "http-over-capnp.h" #include #include +#include namespace capnp { @@ -119,7 +120,11 @@ public: } kj::Maybe> shortenPath() override { - return kj::mv(shorteningPromise); + auto onAbort = webSocket.whenAborted() + .then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); + }); + return shorteningPromise.exclusiveJoin(kj::mv(onAbort)); } kj::Promise sendText(SendTextContext context) override { @@ -322,15 +327,72 @@ private: // Must check state->assertNotCanceled() before using this. }; +class HttpOverCapnpFactory::ConnectClientRequestContextImpl final + : public capnp::HttpService::ConnectClientRequestContext::Server { +public: + ConnectClientRequestContextImpl(HttpOverCapnpFactory& factory, + kj::HttpService::ConnectResponse& connResponse) + : factory(factory), connResponse(connResponse) {} + + kj::Promise startConnect(StartConnectContext context) override { + KJ_REQUIRE(!sent, "already called startConnect() or startError()"); + sent = true; + + auto params = context.getParams(); + auto resp = params.getResponse(); + + auto headers = factory.headersToKj(resp.getHeaders()); + connResponse.accept(resp.getStatusCode(), resp.getStatusText(), headers); + + return kj::READY_NOW; + } + + kj::Promise startError(StartErrorContext context) override { + KJ_REQUIRE(!sent, "already called startConnect() or startError()"); + sent = true; + + auto params = context.getParams(); + auto resp = params.getResponse(); + + auto headers = factory.headersToKj(resp.getHeaders()); + + auto bodySize = resp.getBodySize(); + kj::Maybe expectedSize; + if (bodySize.isFixed()) { + expectedSize = bodySize.getFixed(); + } + + auto stream = connResponse.reject( + resp.getStatusCode(), resp.getStatusText(), headers, expectedSize); + + context.initResults().setBody(factory.streamFactory.kjToCapnp(kj::mv(stream))); + + return kj::READY_NOW; + } + +private: + HttpOverCapnpFactory& factory; + bool sent = false; + + kj::HttpService::ConnectResponse& connResponse; +}; + class HttpOverCapnpFactory::KjToCapnpHttpServiceAdapter final: public kj::HttpService { public: KjToCapnpHttpServiceAdapter(HttpOverCapnpFactory& factory, capnp::HttpService::Client inner) : factory(factory), inner(kj::mv(inner)) {} - kj::Promise request( + template + kj::Promise requestImpl( + Request rpcRequest, kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, - kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { - auto rpcRequest = inner.startRequestRequest(); + kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse, + AwaitCompletionFunc&& awaitCompletion) { + // Common implementation calling request() or startRequest(). awaitCompletion() waits for + // final completion in a method-specific way. + // + // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a + // callback. auto metadata = rpcRequest.initRequest(); metadata.setMethod(static_cast(method)); @@ -369,9 +431,12 @@ public: // Pump upstream -- unless we don't expect a request body. kj::Maybe> pumpRequestTask; KJ_IF_MAYBE(rb, maybeRequestBody) { - auto bodyOut = factory.streamFactory.capnpToKj(pipeline.getRequestBody()); - pumpRequestTask = rb->pumpTo(*bodyOut).attach(kj::mv(bodyOut)).ignoreResult() - .eagerlyEvaluate([state = kj::addRef(*state)](kj::Exception&& e) mutable { + auto bodyOut = factory.streamFactory.capnpToKjExplicitEnd(pipeline.getRequestBody()); + pumpRequestTask = rb->pumpTo(*bodyOut) + .then([&bodyOut = *bodyOut](uint64_t) mutable { + return bodyOut.end(); + }).eagerlyEvaluate([state = kj::addRef(*state), bodyOut = kj::mv(bodyOut)] + (kj::Exception&& e) mutable { // A DISCONNECTED exception probably means the server decided not to read the whole request // before responding. In that case we simply want the pump to end, so that on this end it // also appears that the service simply didn't read everything. So we don't propagate the @@ -383,10 +448,10 @@ public: }); } - // Wait for the ServerRequestContext to resolve, which indicates completion. Meanwhile, if the - // promise is canceled from the client side, we drop the ServerRequestContext naturally, and we + // Wait for the server to indicate completion. Meanwhile, if the + // promise is canceled from the client side, we propagate cancellation naturally, and we // also call state->cancel(). - return pipeline.getContext().whenResolved() + return awaitCompletion(pipeline) // Once the server indicates it is done, then we can cancel pumping the request, because // obviously the server won't use it. We should not cancel pumping the response since there // could be data in-flight still. @@ -396,6 +461,76 @@ public: .attach(kj::mv(deferredCancel)); } + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { + if (factory.peerOptimizationLevel < LEVEL_2) { + return requestImpl(inner.startRequestRequest(), method, url, headers, requestBody, kjResponse, + [](auto& pipeline) { return pipeline.getContext().whenResolved(); }); + } else { + return requestImpl(inner.requestRequest(), method, url, headers, requestBody, kjResponse, + [](auto& pipeline) { return pipeline.ignoreResult(); }); + } + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& connection, + ConnectResponse& tunnel, kj::HttpConnectSettings settings) override { + auto rpcRequest = inner.connectRequest(); + auto downPipe = kj::newOneWayPipe(); + rpcRequest.setHost(host); + rpcRequest.setDown(factory.streamFactory.kjToCapnp(kj::mv(downPipe.out))); + rpcRequest.initSettings().setUseTls(settings.useTls); + + auto context = kj::heap(factory, tunnel); + RevocableServer revocableContext(*context); + + auto builder = capnp::Request< + capnp::HttpService::ConnectParams, + capnp::HttpService::ConnectResults>::Builder(rpcRequest); + rpcRequest.adoptHeaders(factory.headersToCapnp(headers, + Orphanage::getForMessageContaining(builder))); + rpcRequest.setContext(revocableContext.getClient()); + RemotePromise pipeline = rpcRequest.send(); + + // We read from `downPipe` (the other side writes into it.) + auto downPumpTask = downPipe.in->pumpTo(connection) + .then([&connection, down = kj::mv(downPipe.in)](uint64_t) -> kj::Promise { + connection.shutdownWrite(); + return kj::NEVER_DONE; + }); + // We write to `up` (the other side reads from it). + auto up = pipeline.getUp(); + + // We need to create a tlsStarter callback which sends a startTls request to the capnp server. + KJ_IF_MAYBE(tlsStarter, settings.tlsStarter) { + kj::Function(kj::StringPtr)> cb = + [upForStartTls = kj::cp(up)] + (kj::StringPtr expectedServerHostname) + mutable -> kj::Promise { + auto startTlsRpcRequest = upForStartTls.startTlsRequest(); + startTlsRpcRequest.setExpectedServerHostname(expectedServerHostname); + return startTlsRpcRequest.send(); + }; + *tlsStarter = kj::mv(cb); + } + + auto upStream = factory.streamFactory.capnpToKjExplicitEnd(up); + auto upPumpTask = connection.pumpTo(*upStream) + .then([&upStream = *upStream](uint64_t) mutable { + return upStream.end(); + }).then([up = kj::mv(up), upStream = kj::mv(upStream)]() mutable + -> kj::Promise { + return kj::NEVER_DONE; + }); + + return pipeline.ignoreResult() + .attach(kj::mv(downPumpTask), kj::mv(upPumpTask), kj::mv(revocableContext)) + // Separate attach to make sure `revocableContext` is destroyed before `context`. + .attach(kj::mv(context)); + } + + private: HttpOverCapnpFactory& factory; capnp::HttpService::Client inner; @@ -415,7 +550,7 @@ class NullInputStream final: public kj::AsyncInputStream { public: kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); + return kj::constPromise(); } kj::Maybe tryGetLength() override { @@ -423,7 +558,7 @@ public: } kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { - return uint64_t(0); + return kj::constPromise(); } }; @@ -452,38 +587,17 @@ public: } // namespace -class HttpOverCapnpFactory::ServerRequestContextImpl final - : public capnp::HttpService::ServerRequestContext::Server, - public kj::HttpService::Response { +class HttpOverCapnpFactory::HttpServiceResponseImpl + : public kj::HttpService::Response { public: - ServerRequestContextImpl(HttpOverCapnpFactory& factory, - HttpService::Client serviceCap, - capnp::HttpRequest::Reader request, - capnp::HttpService::ClientRequestContext::Client clientContext, - kj::Own requestBodyIn, - kj::HttpService& kjService) - : factory(factory), serviceCap(kj::mv(serviceCap)), + HttpServiceResponseImpl(HttpOverCapnpFactory& factory, + capnp::HttpRequest::Reader request, + capnp::HttpService::ClientRequestContext::Client clientContext) + : factory(factory), method(validateMethod(request.getMethod())), - url(kj::str(request.getUrl())), - headers(factory.headersToKj(request.getHeaders()).clone()), - clientContext(kj::mv(clientContext)), - // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading - // the request body as soon as the service returns. This is important in particular when - // the request body is not fully consumed, in order to propagate cancellation. - task(kjService.request(method, url, headers, *requestBodyIn, *this) - .attach(kj::mv(requestBodyIn))) {} - - KJ_DISALLOW_COPY(ServerRequestContextImpl); - - kj::Maybe> shortenPath() override { - return task.then([]() -> Capability::Client { - // If all went well, resolve to a settled capability. - // TODO(perf): Could save a message by resolving to a capability hosted by the client, or - // some special "null" capability that isn't an error but is still transmitted by value. - // Otherwise we need a Release message from client -> server just to drop this... - return kj::heap(); - }); - } + url(request.getUrl()), + headers(factory.headersToKj(request.getHeaders())), + clientContext(kj::mv(clientContext)) {} kj::Own send( uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, @@ -508,19 +622,16 @@ public: hasBody = *s > 0; } + auto logError = [hasBody](kj::Exception&& e) { + KJ_LOG(INFO, "HTTP-over-RPC startResponse() failed", hasBody, e); + }; if (hasBody) { auto pipeline = req.send(); auto result = factory.streamFactory.capnpToKj(pipeline.getBody()); - replyTask = pipeline.ignoreResult() - .eagerlyEvaluate([](kj::Exception&& e) { - KJ_LOG(ERROR, "HTTP-over-RPC startResponse() failed", e); - }); + replyTask = pipeline.ignoreResult().eagerlyEvaluate(kj::mv(logError)); return result; } else { - replyTask = req.send().ignoreResult() - .eagerlyEvaluate([](kj::Exception&& e) { - KJ_LOG(ERROR, "HTTP-over-RPC startResponse() failed", e); - }); + replyTask = req.send().ignoreResult().eagerlyEvaluate(kj::mv(logError)); return kj::heap(); } @@ -560,21 +671,18 @@ public: // since it holds a reference to `downSocket`. replyTask = pipeline.ignoreResult() .eagerlyEvaluate([](kj::Exception&& e) { - KJ_LOG(ERROR, "HTTP-over-RPC startWebSocketRequest() failed", e); + KJ_LOG(INFO, "HTTP-over-RPC startWebSocketRequest() failed", e); }); return result; } -private: HttpOverCapnpFactory& factory; - HttpService::Client serviceCap; // ensures the inner kj::HttpService isn't destroyed kj::HttpMethod method; - kj::String url; + kj::StringPtr url; kj::HttpHeaders headers; capnp::HttpService::ClientRequestContext::Client clientContext; kj::Maybe> replyTask; - kj::Promise task; static kj::HttpMethod validateMethod(capnp::HttpMethod method) { KJ_REQUIRE(method <= capnp::HttpMethod::UNSUBSCRIBE, "unknown method", method); @@ -582,12 +690,116 @@ private: } }; +class HttpOverCapnpFactory::HttpOverCapnpConnectResponseImpl final : + public kj::HttpService::ConnectResponse { +public: + HttpOverCapnpConnectResponseImpl( + HttpOverCapnpFactory& factory, + capnp::HttpService::ConnectClientRequestContext::Client context) : + context(context), factory(factory) {} + + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_REQUIRE(replyTask == nullptr, "already called accept() or reject()"); + + auto req = context.startConnectRequest(); + auto rpcResponse = req.initResponse(); + rpcResponse.setStatusCode(statusCode); + rpcResponse.setStatusText(statusText); + rpcResponse.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(rpcResponse))); + + replyTask = req.send().ignoreResult(); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(replyTask == nullptr, "already called accept() or reject()"); + auto pipe = kj::newOneWayPipe(expectedBodySize); + + auto req = context.startErrorRequest(); + auto rpcResponse = req.initResponse(); + rpcResponse.setStatusCode(statusCode); + rpcResponse.setStatusText(statusText); + rpcResponse.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(rpcResponse))); + + auto errorBody = kj::mv(pipe.in); + // Set the body size if the error body exists. + KJ_IF_MAYBE(size, errorBody->tryGetLength()) { + rpcResponse.getBodySize().setFixed(*size); + } + + replyTask = req.send().then( + [this, errorBody = kj::mv(errorBody)](auto resp) mutable -> kj::Promise { + auto body = factory.streamFactory.capnpToKjExplicitEnd(resp.getBody()); + return errorBody->pumpTo(*body) + .then([&body = *body](uint64_t) mutable { + return body.end(); + }).attach(kj::mv(errorBody), kj::mv(body)); + }); + + return kj::mv(pipe.out); + } + + capnp::HttpService::ConnectClientRequestContext::Client context; + capnp::HttpOverCapnpFactory& factory; + kj::Maybe> replyTask; +}; + + +class HttpOverCapnpFactory::ServerRequestContextImpl final + : public capnp::HttpService::ServerRequestContext::Server, + public HttpServiceResponseImpl { +public: + ServerRequestContextImpl(HttpOverCapnpFactory& factory, + HttpService::Client serviceCap, + kj::Own request, + capnp::HttpService::ClientRequestContext::Client clientContext, + kj::Own requestBodyIn, + kj::HttpService& kjService) + : HttpServiceResponseImpl(factory, *request, kj::mv(clientContext)), + request(kj::mv(request)), + serviceCap(kj::mv(serviceCap)), + // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading + // the request body as soon as the service returns. This is important in particular when + // the request body is not fully consumed, in order to propagate cancellation. + task(kjService.request(method, url, headers, *requestBodyIn, *this) + .attach(kj::mv(requestBodyIn))) {} + + kj::Maybe> shortenPath() override { + return task.then([]() -> Capability::Client { + // If all went well, resolve to a settled capability. + // TODO(perf): Could save a message by resolving to a capability hosted by the client, or + // some special "null" capability that isn't an error but is still transmitted by value. + // Otherwise we need a Release message from client -> server just to drop this... + return kj::heap(); + }); + } + + KJ_DISALLOW_COPY_AND_MOVE(ServerRequestContextImpl); + +private: + kj::Own request; + HttpService::Client serviceCap; // ensures the inner kj::HttpService isn't destroyed + kj::Promise task; +}; + class HttpOverCapnpFactory::CapnpToKjHttpServiceAdapter final: public capnp::HttpService::Server { public: CapnpToKjHttpServiceAdapter(HttpOverCapnpFactory& factory, kj::Own inner) : factory(factory), inner(kj::mv(inner)) {} - kj::Promise startRequest(StartRequestContext context) override { + template + kj::Promise requestImpl(CallContext context, Callback&& callback) { + // Common implementation of request() and startRequest(). callback() performs the + // method-specific stuff at the end. + // + // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a + // callback. + auto params = context.getParams(); auto metadata = params.getRequest(); @@ -604,15 +816,118 @@ public: kj::Own requestBody; if (hasBody) { auto pipe = kj::newOneWayPipe(expectedSize); - results.setRequestBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out))); + auto requestBodyCap = factory.streamFactory.kjToCapnp(kj::mv(pipe.out)); + + if (kj::isSameType()) { + // For request(), use context.setPipeline() to enable pipelined calls to the request body + // stream before this RPC completes. (We don't bother when using startRequest() because + // it returns immediately anyway, so this would just waste effort.) + PipelineBuilder pipeline; + pipeline.setRequestBody(kj::cp(requestBodyCap)); + context.setPipeline(pipeline.build()); + } + + results.setRequestBody(kj::mv(requestBodyCap)); requestBody = kj::mv(pipe.in); } else { requestBody = kj::heap(); } - results.setContext(kj::heap( - factory, thisCap(), metadata, params.getContext(), kj::mv(requestBody), *inner)); - return kj::READY_NOW; + return callback(results, metadata, params, requestBody); + } + + kj::Promise request(RequestContext context) override { + return requestImpl(kj::mv(context), + [&](auto& results, auto& metadata, auto& params, auto& requestBody) { + class FinalHttpServiceResponseImpl final: public HttpServiceResponseImpl { + public: + using HttpServiceResponseImpl::HttpServiceResponseImpl; + }; + auto impl = kj::heap(factory, metadata, params.getContext()); + auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); + return promise.attach(kj::mv(requestBody), kj::mv(impl)); + }); + } + + kj::Promise startRequest(StartRequestContext context) override { + return requestImpl(kj::mv(context), + [&](auto& results, auto& metadata, auto& params, auto& requestBody) { + results.setContext(kj::heap( + factory, thisCap(), capnp::clone(metadata), params.getContext(), kj::mv(requestBody), + *inner)); + + return kj::READY_NOW; + }); + } + + kj::Promise connect(ConnectContext context) override { + auto params = context.getParams(); + auto host = params.getHost(); + kj::Own tlsStarter = kj::heap(); + kj::HttpConnectSettings settings = { .useTls = params.getSettings().getUseTls()}; + settings.tlsStarter = tlsStarter; + auto headers = factory.headersToKj(params.getHeaders()); + auto pipe = kj::newTwoWayPipe(); + + class EofDetector final: public kj::AsyncOutputStream { + public: + EofDetector(kj::Own inner) + : inner(kj::mv(inner)) {} + ~EofDetector() { + inner->shutdownWrite(); + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + + kj::Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + private: + kj::Own inner; + }; + + auto stream = factory.streamFactory.capnpToKjExplicitEnd(context.getParams().getDown()); + + // We want to keep the stream alive even after EofDetector is destroyed, so we need to create + // a refcounted AsyncIoStream. + auto refcounted = kj::refcountedWrapper(kj::mv(pipe.ends[1])); + kj::Own ref1 = refcounted->addWrappedRef(); + kj::Own ref2 = refcounted->addWrappedRef(); + + // We write to the `down` pipe. + auto pumpTask = ref1->pumpTo(*stream) + .then([&stream = *stream](uint64_t) mutable { + return stream.end(); + }).then([httpProxyStream = kj::mv(ref1), stream = kj::mv(stream)]() mutable + -> kj::Promise { + return kj::NEVER_DONE; + }); + + PipelineBuilder pb; + auto eofWrapper = kj::heap(kj::mv(ref2)); + auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); + pb.setUp(kj::cp(up)); + + context.setPipeline(pb.build()); + context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + + auto response = kj::heap( + factory, context.getParams().getContext()); + + return inner->connect(host, headers, *pipe.ends[0], *response, settings).attach( + kj::mv(host), kj::mv(headers), kj::mv(response), kj::mv(pipe)) + .exclusiveJoin(kj::mv(pumpTask)); } private: @@ -658,8 +973,10 @@ HttpOverCapnpFactory::HeaderIdBundle HttpOverCapnpFactory::HeaderIdBundle::clone } HttpOverCapnpFactory::HttpOverCapnpFactory(ByteStreamFactory& streamFactory, - HeaderIdBundle headerIds) + HeaderIdBundle headerIds, + OptimizationLevel peerOptimizationLevel) : streamFactory(streamFactory), headerTable(headerIds.table), + peerOptimizationLevel(peerOptimizationLevel), nameCapnpToKj(kj::mv(headerIds.nameCapnpToKj)) { auto commonHeaderNames = Schema::from().getEnumerants(); nameKjToCapnp = kj::heapArray(headerIds.maxHeaderId + 1); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.capnp index eb8578de696..8b4afba0433 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.capnp @@ -24,16 +24,50 @@ using import "byte-stream.capnp".ByteStream; -$import "/capnp/c++.capnp".namespace("capnp"); +using Cxx = import "/capnp/c++.capnp"; +$Cxx.namespace("capnp"); +$Cxx.allowCancellation; interface HttpService { - startRequest @0 (request :HttpRequest, context :ClientRequestContext) - -> (requestBody :ByteStream, context :ServerRequestContext); - # Begin an HTTP request. + request @1 (request :HttpRequest, context :ClientRequestContext) + -> (requestBody :ByteStream); + # Perform an HTTP request. # # The client sends the request method/url/headers. The server responds with a `ByteStream` where # the client can make calls to stream up the request body. `requestBody` will be null in the case # that request.bodySize.fixed == 0. + # + # The server will send a response by invoking a method on `callback`. + # + # `request()` does not return until the server is completely done processing the request, + # including sending the response. The client therefore must use promise pipelining to send the + # request body. The client may request cancellation of the HTTP request by canceling the + # `request()` call itself. + + startRequest @0 (request :HttpRequest, context :ClientRequestContext) + -> (requestBody :ByteStream, context :ServerRequestContext); + # DEPRECATED: Older form of `request()`. In this version, the server immediately returns a + # `ServerRequestContext` before it begins processing the request. This version was designed + # before `CallContext::setPipeline()` was introduced. At that time, it was impossible for the + # server to receive data sent to the `requestBody` stream until `startRequest()` had returned + # a stream capability to use, hence the ongoing call on the server side had to be represented + # using a separate capability. Now that we have `CallContext::setPipeline()`, the server can + # begin receiving the request body without returning from the top-level RPC, so we can now use + # `request()` instead of `startRequest()`. The new approach is more intuitive and avoids some + # unnecessary bookkeeping. + # + # `HttpOverCapnpFactory` will continue to support both methods. Use the `peerOptimizationLevel` + # constructor parameter to specify which method to use, for backwards-compatibiltiy purposes. + + connect @2 (host :Text, headers :List(HttpHeader), down :ByteStream, + context :ConnectClientRequestContext, settings :ConnectSettings) + -> (up :ByteStream); + # Setup an HTTP CONNECT proxy tunnel. + # + # The client sends the request host/headers together with a `down` ByteStream that will be used + # for communication across the tunnel. The server will respond with the other side of that + # ByteStream for two-way communication. The `context` includes callbacks which are used to + # supply the client with headers. interface ClientRequestContext { # Provides callbacks for the server to send the response. @@ -52,7 +86,20 @@ interface HttpService { # Server -> Client will be sent as calls to `downSocket`. } + interface ConnectClientRequestContext { + # Provides callbacks for the server to send the response. + + startConnect @0 (response :HttpResponse); + # Server calls this method to let the client know that the CONNECT request has been + # accepted. It also includes status code and header information. + + startError @1 (response :HttpResponse) -> (body :ByteStream); + # Server calls this method to let the client know that the CONNECT request has been rejected. + } + interface ServerRequestContext { + # DEPRECATED: Used only with startRequest(); see comments there. + # # Represents execution of a particular request on the server side. # # Dropping this object before the request completes will cancel the request. @@ -65,6 +112,10 @@ interface HttpService { } } +struct ConnectSettings { + useTls @0 :Bool; +} + interface WebSocket { sendText @0 (text :Text) -> stream; sendData @1 (data :Data) -> stream; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.h index aedd6f0d835..6b16749118b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/http-over-capnp.h @@ -27,6 +27,8 @@ #include #include "byte-stream.h" +CAPNP_BEGIN_HEADER + namespace capnp { class HttpOverCapnpFactory { @@ -50,7 +52,21 @@ class HttpOverCapnpFactory { friend class HttpOverCapnpFactory; }; - HttpOverCapnpFactory(ByteStreamFactory& streamFactory, HeaderIdBundle headerIds); + enum OptimizationLevel { + // Specifies the protocol optimization level supported by the remote peer. Setting this higher + // will improve efficiency but breaks compatibility with older peers that don't implement newer + // levels. + + LEVEL_1, + // Use startRequest(), the original version of the protocol. + + LEVEL_2 + // Use request(). This is more efficient than startRequest() but won't work with old peers that + // only implement startRequest(). + }; + + HttpOverCapnpFactory(ByteStreamFactory& streamFactory, HeaderIdBundle headerIds, + OptimizationLevel peerOptimizationLevel = LEVEL_1); kj::Own capnpToKj(capnp::HttpService::Client rpcService); capnp::HttpService::Client kjToCapnp(kj::Own service); @@ -58,6 +74,7 @@ class HttpOverCapnpFactory { private: ByteStreamFactory& streamFactory; const kj::HttpHeaderTable& headerTable; + OptimizationLevel peerOptimizationLevel; kj::Array nameKjToCapnp; kj::Array nameCapnpToKj; kj::Array valueCapnpToKj; @@ -69,8 +86,11 @@ class HttpOverCapnpFactory { class KjToCapnpWebSocketAdapter; class ClientRequestContextImpl; + class ConnectClientRequestContextImpl; class KjToCapnpHttpServiceAdapter; + class HttpServiceResponseImpl; + class HttpOverCapnpConnectResponseImpl; class ServerRequestContextImpl; class CapnpToKjHttpServiceAdapter; @@ -82,3 +102,5 @@ class HttpOverCapnpFactory { }; } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-rpc.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-rpc.h index 6954c644ede..c4d3b997001 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-rpc.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-rpc.h @@ -26,6 +26,8 @@ #include #include +CAPNP_BEGIN_HEADER + namespace kj { class HttpInputStream; } namespace capnp { @@ -42,7 +44,7 @@ class JsonRpc: private kj::TaskSet::ErrorHandler { class ContentLengthTransport; JsonRpc(Transport& transport, DynamicCapability::Client interface = {}); - KJ_DISALLOW_COPY(JsonRpc); + KJ_DISALLOW_COPY_AND_MOVE(JsonRpc); DynamicCapability::Client getPeer(InterfaceSchema schema); @@ -98,7 +100,7 @@ class JsonRpc::ContentLengthTransport: public Transport { public: explicit ContentLengthTransport(kj::AsyncIoStream& stream); ~ContentLengthTransport() noexcept(false); - KJ_DISALLOW_COPY(ContentLengthTransport); + KJ_DISALLOW_COPY_AND_MOVE(ContentLengthTransport); kj::Promise send(kj::StringPtr text) override; kj::Promise receive() override; @@ -110,3 +112,5 @@ class JsonRpc::ContentLengthTransport: public Transport { }; } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-test.c++ index 7a632ebab83..550ecafa7d2 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json-test.c++ @@ -50,6 +50,17 @@ KJ_TEST("basic json encoding") { KJ_EXPECT(json.encode(Data::Reader(bytes, 3)) == "[12, 34, 56]"); } +KJ_TEST("raw encoding") { + JsonCodec json; + + auto text = kj::str("{\"field\":\"value\"}"); + MallocMessageBuilder message; + auto value = message.initRoot(); + value.setRaw(text); + + KJ_EXPECT(json.encodeRaw(value) == text); +} + const char ALL_TYPES_JSON[] = "{ \"voidField\": null,\n" " \"boolField\": true,\n" @@ -606,6 +617,17 @@ KJ_TEST("basic json decoding") { KJ_EXPECT_THROW_MESSAGE("Unexpected input", json.decodeRaw("\f{}", root)); KJ_EXPECT_THROW_MESSAGE("Unexpected input", json.decodeRaw("{\v}", root)); } + + { + MallocMessageBuilder message; + auto root = message.initRoot(); + + json.decodeRaw(R"("\u007f")", root); + KJ_EXPECT(root.which() == JsonValue::STRING); + + char utf_buffer[] = {127, 0}; + KJ_EXPECT(kj::str(utf_buffer) == root.getString()); + } } KJ_TEST("maximum nesting depth") { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.c++ index dd8b07b0e12..83a522f4ee6 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.c++ @@ -93,6 +93,10 @@ struct JsonCodec::Impl { return kj::strTree(call.getFunction(), '(', encodeList( kj::mv(encodedElements), childMultiline, indent, multiline, true), ')'); } + + case JsonValue::RAW: { + return kj::strTree(value.getRaw()); + } } KJ_FAIL_ASSERT("unknown JsonValue type", static_cast(value.which())); @@ -741,7 +745,7 @@ public: private: kj::String consumeQuotedString() { input.consume('"'); - // TODO(perf): Avoid copy / alloc if no escapes encoutered. + // TODO(perf): Avoid copy / alloc if no escapes encountered. // TODO(perf): Get statistics on string size and preallocate? kj::Vector decoded; @@ -819,9 +823,9 @@ private: if ('0' <= c && c <= '9') { codePoint |= c - '0'; } else if ('a' <= c && c <= 'f') { - codePoint |= c - 'a'; + codePoint |= c - 'a' + 10; } else if ('A' <= c && c <= 'F') { - codePoint |= c - 'A'; + codePoint |= c - 'A' + 10; } else { KJ_FAIL_REQUIRE("Invalid hex digit in unicode escape.", c); } @@ -1324,7 +1328,7 @@ private: const void* getUnionInstanceIdentifier(DynamicStruct::Builder obj) const { // Gets a value uniquely identifying an instance of a union. - // HACK: We return a poniter to the union's discriminant within the underlying buffer. + // HACK: We return a pointer to the union's discriminant within the underlying buffer. return reinterpret_cast( AnyStruct::Reader(obj.asReader()).getDataSection().begin()) + discriminantOffset; } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp index ca8a8f3ebac..e5d1870c00f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp @@ -44,6 +44,18 @@ struct Value { # "binary" and "date" types in text, since JSON has no analog of these. This is basically the # reason this extension exists. We do NOT recommend using `call` unless you specifically need # to be compatible with some silly format that uses this syntax. + + raw @7 :Text; + # Used to indicate that the text should be written directly to the output without + # modifications. Use this if you have an already serialized JSON value and don't want + # to feel the cost of deserializing the value just to serialize it again. + # + # The parser will never produce a `raw` value -- this is only useful for serialization. + # + # WARNING: You MUST ensure that the value is valid stand-alone JSOn. It will not be verified. + # Invalid JSON could mjake the whole message unparsable. Worse, a malicious raw value could + # perform JSON injection attacks. Make sure that the value was produced by a trustworthy JSON + # encoder. } struct Field { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.c++ index 3556247666f..faf41e50660 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.c++ @@ -5,17 +5,17 @@ namespace capnp { namespace schemas { -static const ::capnp::_::AlignedData<137> b_a3fa7845f919dd83 = { +static const ::capnp::_::AlignedData<152> b_a3fa7845f919dd83 = { { 0, 0, 0, 0, 5, 0, 6, 0, 131, 221, 25, 249, 69, 120, 250, 163, 24, 0, 0, 0, 1, 0, 2, 0, 52, 94, 58, 164, 151, 146, 249, 142, - 1, 0, 7, 0, 0, 0, 7, 0, + 1, 0, 7, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 242, 0, 0, 0, 33, 0, 0, 0, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 53, 0, 0, 0, 143, 1, 0, 0, + 53, 0, 0, 0, 199, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 99, 111, @@ -29,56 +29,63 @@ static const ::capnp::_::AlignedData<137> b_a3fa7845f919dd83 = { 5, 0, 0, 0, 42, 0, 0, 0, 70, 105, 101, 108, 100, 0, 0, 0, 67, 97, 108, 108, 0, 0, 0, 0, - 28, 0, 0, 0, 3, 0, 4, 0, + 32, 0, 0, 0, 3, 0, 4, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 181, 0, 0, 0, 42, 0, 0, 0, + 209, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 176, 0, 0, 0, 3, 0, 1, 0, - 188, 0, 0, 0, 2, 0, 1, 0, + 204, 0, 0, 0, 3, 0, 1, 0, + 216, 0, 0, 0, 2, 0, 1, 0, 1, 0, 254, 255, 16, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 185, 0, 0, 0, 66, 0, 0, 0, + 213, 0, 0, 0, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 180, 0, 0, 0, 3, 0, 1, 0, - 192, 0, 0, 0, 2, 0, 1, 0, + 208, 0, 0, 0, 3, 0, 1, 0, + 220, 0, 0, 0, 2, 0, 1, 0, 2, 0, 253, 255, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 189, 0, 0, 0, 58, 0, 0, 0, + 217, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 184, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, + 212, 0, 0, 0, 3, 0, 1, 0, + 224, 0, 0, 0, 2, 0, 1, 0, 3, 0, 252, 255, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 58, 0, 0, 0, + 221, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 188, 0, 0, 0, 3, 0, 1, 0, - 200, 0, 0, 0, 2, 0, 1, 0, + 216, 0, 0, 0, 3, 0, 1, 0, + 228, 0, 0, 0, 2, 0, 1, 0, 4, 0, 251, 255, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 197, 0, 0, 0, 50, 0, 0, 0, + 225, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 192, 0, 0, 0, 3, 0, 1, 0, - 220, 0, 0, 0, 2, 0, 1, 0, + 220, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 5, 0, 250, 255, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 217, 0, 0, 0, 58, 0, 0, 0, + 245, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 212, 0, 0, 0, 3, 0, 1, 0, - 240, 0, 0, 0, 2, 0, 1, 0, + 240, 0, 0, 0, 3, 0, 1, 0, + 12, 1, 0, 0, 2, 0, 1, 0, 6, 0, 249, 255, 0, 0, 0, 0, 0, 0, 1, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 237, 0, 0, 0, 42, 0, 0, 0, + 9, 1, 0, 0, 42, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 248, 255, 0, 0, 0, 0, + 0, 0, 1, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 13, 1, 0, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 232, 0, 0, 0, 3, 0, 1, 0, - 244, 0, 0, 0, 2, 0, 1, 0, + 8, 1, 0, 0, 3, 0, 1, 0, + 20, 1, 0, 0, 2, 0, 1, 0, 110, 117, 108, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -141,6 +148,14 @@ static const ::capnp::_::AlignedData<137> b_a3fa7845f919dd83 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 114, 97, 119, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -151,11 +166,11 @@ static const ::capnp::_::RawSchema* const d_a3fa7845f919dd83[] = { &s_a3fa7845f919dd83, &s_e31026e735d69ddf, }; -static const uint16_t m_a3fa7845f919dd83[] = {4, 1, 6, 0, 2, 5, 3}; -static const uint16_t i_a3fa7845f919dd83[] = {0, 1, 2, 3, 4, 5, 6}; +static const uint16_t m_a3fa7845f919dd83[] = {4, 1, 6, 0, 2, 5, 7, 3}; +static const uint16_t i_a3fa7845f919dd83[] = {0, 1, 2, 3, 4, 5, 6, 7}; const ::capnp::_::RawSchema s_a3fa7845f919dd83 = { - 0xa3fa7845f919dd83, b_a3fa7845f919dd83.words, 137, d_a3fa7845f919dd83, m_a3fa7845f919dd83, - 3, 7, i_a3fa7845f919dd83, nullptr, nullptr, { &s_a3fa7845f919dd83, nullptr, nullptr, 0, 0, nullptr } + 0xa3fa7845f919dd83, b_a3fa7845f919dd83.words, 152, d_a3fa7845f919dd83, m_a3fa7845f919dd83, + 3, 8, i_a3fa7845f919dd83, nullptr, nullptr, { &s_a3fa7845f919dd83, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_e31026e735d69ddf = { @@ -218,7 +233,7 @@ static const uint16_t m_e31026e735d69ddf[] = {0, 1}; static const uint16_t i_e31026e735d69ddf[] = {0, 1}; const ::capnp::_::RawSchema s_e31026e735d69ddf = { 0xe31026e735d69ddf, b_e31026e735d69ddf.words, 49, d_e31026e735d69ddf, m_e31026e735d69ddf, - 1, 2, i_e31026e735d69ddf, nullptr, nullptr, { &s_e31026e735d69ddf, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_e31026e735d69ddf, nullptr, nullptr, { &s_e31026e735d69ddf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<54> b_a0d9f6eca1c93d48 = { @@ -286,7 +301,7 @@ static const uint16_t m_a0d9f6eca1c93d48[] = {0, 1}; static const uint16_t i_a0d9f6eca1c93d48[] = {0, 1}; const ::capnp::_::RawSchema s_a0d9f6eca1c93d48 = { 0xa0d9f6eca1c93d48, b_a0d9f6eca1c93d48.words, 54, d_a0d9f6eca1c93d48, m_a0d9f6eca1c93d48, - 1, 2, i_a0d9f6eca1c93d48, nullptr, nullptr, { &s_a0d9f6eca1c93d48, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_a0d9f6eca1c93d48, nullptr, nullptr, { &s_a0d9f6eca1c93d48, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<21> b_fa5b1fd61c2e7c3d = { @@ -316,7 +331,7 @@ static const ::capnp::_::AlignedData<21> b_fa5b1fd61c2e7c3d = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_fa5b1fd61c2e7c3d = { 0xfa5b1fd61c2e7c3d, b_fa5b1fd61c2e7c3d.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_fa5b1fd61c2e7c3d, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_fa5b1fd61c2e7c3d, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<21> b_82d3e852af0336bf = { @@ -346,7 +361,7 @@ static const ::capnp::_::AlignedData<21> b_82d3e852af0336bf = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_82d3e852af0336bf = { 0x82d3e852af0336bf, b_82d3e852af0336bf.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_82d3e852af0336bf, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_82d3e852af0336bf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<35> b_c4df13257bc2ea61 = { @@ -392,7 +407,7 @@ static const uint16_t m_c4df13257bc2ea61[] = {0}; static const uint16_t i_c4df13257bc2ea61[] = {0}; const ::capnp::_::RawSchema s_c4df13257bc2ea61 = { 0xc4df13257bc2ea61, b_c4df13257bc2ea61.words, 35, nullptr, m_c4df13257bc2ea61, - 0, 1, i_c4df13257bc2ea61, nullptr, nullptr, { &s_c4df13257bc2ea61, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_c4df13257bc2ea61, nullptr, nullptr, { &s_c4df13257bc2ea61, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<22> b_cfa794e8d19a0162 = { @@ -423,7 +438,7 @@ static const ::capnp::_::AlignedData<22> b_cfa794e8d19a0162 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_cfa794e8d19a0162 = { 0xcfa794e8d19a0162, b_cfa794e8d19a0162.words, 22, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_cfa794e8d19a0162, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_cfa794e8d19a0162, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_c2f8c20c293e5319 = { @@ -485,7 +500,7 @@ static const uint16_t m_c2f8c20c293e5319[] = {0, 1}; static const uint16_t i_c2f8c20c293e5319[] = {0, 1}; const ::capnp::_::RawSchema s_c2f8c20c293e5319 = { 0xc2f8c20c293e5319, b_c2f8c20c293e5319.words, 51, nullptr, m_c2f8c20c293e5319, - 0, 2, i_c2f8c20c293e5319, nullptr, nullptr, { &s_c2f8c20c293e5319, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_c2f8c20c293e5319, nullptr, nullptr, { &s_c2f8c20c293e5319, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<21> b_d7d879450a253e4b = { @@ -515,7 +530,7 @@ static const ::capnp::_::AlignedData<21> b_d7d879450a253e4b = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_d7d879450a253e4b = { 0xd7d879450a253e4b, b_d7d879450a253e4b.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_d7d879450a253e4b, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_d7d879450a253e4b, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<21> b_f061e22f0ae5c7b5 = { @@ -545,7 +560,7 @@ static const ::capnp::_::AlignedData<21> b_f061e22f0ae5c7b5 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_f061e22f0ae5c7b5 = { 0xf061e22f0ae5c7b5, b_f061e22f0ae5c7b5.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_f061e22f0ae5c7b5, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_f061e22f0ae5c7b5, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<22> b_a0a054dea32fd98c = { @@ -576,7 +591,7 @@ static const ::capnp::_::AlignedData<22> b_a0a054dea32fd98c = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_a0a054dea32fd98c = { 0xa0a054dea32fd98c, b_a0a054dea32fd98c.words, 22, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_a0a054dea32fd98c, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_a0a054dea32fd98c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -588,43 +603,63 @@ namespace capnp { namespace json { // Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Value::_capnpPrivate::dataWordSize; constexpr uint16_t Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Value::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Value::Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Value::Field::_capnpPrivate::dataWordSize; constexpr uint16_t Value::Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Value::Field::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Value::Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Value::Call +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Value::Call::_capnpPrivate::dataWordSize; constexpr uint16_t Value::Call::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Value::Call::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Value::Call::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // FlattenOptions +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t FlattenOptions::_capnpPrivate::dataWordSize; constexpr uint16_t FlattenOptions::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind FlattenOptions::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* FlattenOptions::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // DiscriminatorOptions +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t DiscriminatorOptions::_capnpPrivate::dataWordSize; constexpr uint16_t DiscriminatorOptions::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind DiscriminatorOptions::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* DiscriminatorOptions::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.h index 1454e3e6494..35a218ce5bb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.capnp.h @@ -9,7 +9,9 @@ #include #endif // !CAPNP_LITE -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif @@ -51,6 +53,7 @@ struct Value { ARRAY, OBJECT, CALL, + RAW, }; struct Field; struct Call; @@ -168,6 +171,10 @@ class Value::Reader { inline bool hasCall() const; inline ::capnp::json::Value::Call::Reader getCall() const; + inline bool isRaw() const; + inline bool hasRaw() const; + inline ::capnp::Text::Reader getRaw() const; + private: ::capnp::_::StructReader _reader; template @@ -241,6 +248,14 @@ class Value::Builder { inline void adoptCall(::capnp::Orphan< ::capnp::json::Value::Call>&& value); inline ::capnp::Orphan< ::capnp::json::Value::Call> disownCall(); + inline bool isRaw(); + inline bool hasRaw(); + inline ::capnp::Text::Builder getRaw(); + inline void setRaw( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initRaw(unsigned int size); + inline void adoptRaw(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownRaw(); + private: ::capnp::_::StructBuilder _builder; template @@ -927,6 +942,60 @@ inline ::capnp::Orphan< ::capnp::json::Value::Call> Value::Builder::disownCall() ::capnp::bounded<0>() * ::capnp::POINTERS)); } +inline bool Value::Reader::isRaw() const { + return which() == Value::RAW; +} +inline bool Value::Builder::isRaw() { + return which() == Value::RAW; +} +inline bool Value::Reader::hasRaw() const { + if (which() != Value::RAW) return false; + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool Value::Builder::hasRaw() { + if (which() != Value::RAW) return false; + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader Value::Reader::getRaw() const { + KJ_IREQUIRE((which() == Value::RAW), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder Value::Builder::getRaw() { + KJ_IREQUIRE((which() == Value::RAW), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void Value::Builder::setRaw( ::capnp::Text::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder Value::Builder::initRaw(unsigned int size) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void Value::Builder::adoptRaw( + ::capnp::Orphan< ::capnp::Text>&& value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> Value::Builder::disownRaw() { + KJ_IREQUIRE((which() == Value::RAW), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + inline bool Value::Field::Reader::hasName() const { return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.h index f5dbf38b423..8ce477ed377 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/json.h @@ -25,6 +25,8 @@ #include #include +CAPNP_BEGIN_HEADER + namespace capnp { typedef json::Value JsonValue; @@ -523,3 +525,5 @@ void JsonCodec::handleByAnnotation() { } } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/std-iterator.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/std-iterator.h index 1e6d7b947a0..aac249d7b10 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/std-iterator.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/std-iterator.h @@ -29,11 +29,19 @@ #include "../list.h" #include +CAPNP_BEGIN_HEADER + namespace std { template -struct iterator_traits> - : public std::iterator {}; +struct iterator_traits> { + using iterator_category = std::random_access_iterator_tag; + using value_type = Element; + using difference_type = int; + using pointer = Element*; + using reference = Element; +}; } // namespace std +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/websocket-rpc.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/websocket-rpc.h index a94b27adc24..80c8ae25373 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compat/websocket-rpc.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compat/websocket-rpc.h @@ -24,6 +24,8 @@ #include #include +CAPNP_BEGIN_HEADER + namespace capnp { class WebSocketMessageStream final : public MessageStream { @@ -51,3 +53,5 @@ class WebSocketMessageStream final : public MessageStream { }; } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp-test.sh b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp-test.sh index b35e5072e3e..c435a3d3a88 100755 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp-test.sh +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp-test.sh @@ -119,6 +119,9 @@ test_eval 'TestListDefaults.lists.int32ListList[2][0]' 12341234 test "x`$CAPNP eval $SCHEMA -ojson globalPrintableStruct | tr -d '\r'`" = "x{\"someText\": \"foo\"}" || fail eval json "globalPrintableStruct == {someText = \"foo\"}" +$CAPNP eval $TESTDATA/no-file-id.capnp.nobuild foo >/dev/null || fail eval "file without file ID can be parsed" +test "x`$CAPNP eval $TESTDATA/no-file-id.capnp.nobuild foo | tr -d '\r'`" = 'x"bar"' || fail eval "file without file ID parsed correctly" + $CAPNP compile --no-standard-import --src-prefix="$PREFIX" -ofoo $TESTDATA/errors.capnp.nobuild 2>&1 | sed -e "s,^.*errors[.]capnp[.]nobuild:,file:,g" | tr -d '\r' | diff -u $TESTDATA/errors.txt - || fail error output diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp.c++ index 79569546dac..b00b41ddc5d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnp.c++ @@ -23,6 +23,11 @@ #define _GNU_SOURCE #endif +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t and ino_t, otherwise this code will break when either value exceeds 2^32. +#endif + #if _WIN32 #include #endif @@ -249,6 +254,12 @@ public: // Default convert to text unless -o is given. convertTo = Format::TEXT; + // When using `capnp eval`, type IDs don't really matter, because `eval` won't actually use + // them for anything. When using Cap'n Proto an a config format -- the common use case for + // `capnp eval` -- the exercise of adding a file ID to every file is pointless busy work. So, + // we don't require it. + loader.setFileIdsRequired(false); + kj::MainBuilder builder(context, VERSION_STRING, "Prints (or encodes) the value of , which must be defined in . " " must refer to a const declaration, a field of a struct type (prints the default " diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnpc-c++.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnpc-c++.c++ index 853375f4267..a60d4770e36 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnpc-c++.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/capnpc-c++.c++ @@ -64,6 +64,7 @@ namespace { static constexpr uint64_t NAMESPACE_ANNOTATION_ID = 0xb9c6f99ebf805f2cull; static constexpr uint64_t NAME_ANNOTATION_ID = 0xf264a779fef191ceull; +static constexpr uint64_t ALLOW_CANCELLATION_ANNOTATION_ID = 0xac7096ff8cfc9dceull; bool hasDiscriminantValue(const schema::Field::Reader& reader) { return reader.getDiscriminantValue() != schema::Field::NO_DISCRIMINANT; @@ -430,7 +431,7 @@ private: #if 0 // Figure out exactly how many params are not bound to AnyPointer. - // TODO(msvc): In a few obscure cases, MSVC does not like empty template pramater lists, + // TODO(msvc): In a few obscure cases, MSVC does not like empty template parameter lists, // even if all parameters have defaults. So, we give in and explicitly list all // parameters in our generated code for now. Try again later. uint paramCount = 0; @@ -2045,11 +2046,15 @@ private: kj::StringTree defineText = kj::strTree( "// ", fullName, "\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::dataWordSize;\n", - templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::pointerCount;\n" + templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::pointerCount;\n", + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", "#if !CAPNP_LITE\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr ::capnp::Kind ", fullName, "::_capnpPrivate::kind;\n", - templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n"); + templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n", + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n"); if (templateContext.isGeneric()) { auto brandInitializers = makeBrandInitializers(templateContext, schema); @@ -2236,6 +2241,8 @@ private: // the `CAPNP_AUTO_IF_MSVC()` hackery in the return type declarations below. We're depending on // the fact that that this function has an inline implementation for the deduction to work. + bool noPromisePipelining = !resultSchema.mayContainCapabilities(); + auto requestMethodImpl = kj::strTree( templateContext.allDecls(), implicitParamsTemplateDecl, @@ -2247,9 +2254,24 @@ private: isStreaming ? kj::strTree(" return newStreamingCall<", paramType, ">(\n") : kj::strTree(" return newCall<", paramType, ", ", resultType, ">(\n"), - " 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint);\n" + " 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint, {", noPromisePipelining, "});\n" "}\n"); + bool allowCancellation = false; + if (annotationValue(proto, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } else if (annotationValue(interfaceProto, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } else { + schema::Node::Reader node = interfaceProto; + while (!node.isFile()) { + node = schemaLoader.get(node.getScopeId()).getProto(); + } + if (annotationValue(node, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } + } + return MethodText { kj::strTree( implicitParamsTemplateDecl.size() == 0 ? "" : " ", implicitParamsTemplateDecl, @@ -2297,7 +2319,8 @@ private: " return ", identifierName, "(::capnp::Capability::Server::internalGetTypedStreamingContext<\n" " ", genericParamType, ">(context));\n" " }),\n" - " true\n" + " true,\n" + " ", allowCancellation, "\n" " };\n") : kj::strTree( // For non-streaming calls we let exceptions just flow through for a little more @@ -2305,7 +2328,8 @@ private: " return {\n" " ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n" " ", genericParamType, ", ", genericResultType, ">(context)),\n" - " false\n" + " false,\n" + " ", allowCancellation, "\n" " };\n")) }; } @@ -2372,8 +2396,10 @@ private: kj::StringTree defineText = kj::strTree( "// ", fullName, "\n", "#if !CAPNP_LITE\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr ::capnp::Kind ", fullName, "::_capnpPrivate::kind;\n", - templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n"); + templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n" + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n"); if (templateContext.isGeneric()) { auto brandInitializers = makeBrandInitializers(templateContext, schema); @@ -2583,9 +2609,7 @@ private: kj::strTree("static constexpr ", typeName_, ' ', upperCase, " = ", literalValue(schema.getType(), constProto.getValue()), ";\n"), scope.size() == 0 ? kj::strTree() : kj::strTree( - // TODO(msvc): MSVC doesn't like definitions of constexprs, but other compilers and - // the standard require them. - "#if !defined(_MSC_VER) || defined(__clang__)\n" + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n" "constexpr ", typeName_, ' ', scope, upperCase, ";\n" "#endif\n") }; @@ -2771,6 +2795,8 @@ private: auto brandDeps = makeBrandDepInitializers( makeBrandDepMap(templateContext, schema.getGeneric())); + bool mayContainCapabilities = proto.isStruct() && schema.asStruct().mayContainCapabilities(); + auto schemaDef = kj::strTree( "static const ::capnp::_::AlignedData<", rawSchema.size(), "> b_", hexId, " = {\n" " {", kj::mv(schemaLiteral), " }\n" @@ -2803,7 +2829,7 @@ private: ", nullptr, nullptr, { &s_", hexId, ", nullptr, ", brandDeps.size() == 0 ? kj::strTree("nullptr, 0, 0") : kj::strTree( "bd_", hexId, ", 0, " "sizeof(bd_", hexId, ") / sizeof(bd_", hexId, "[0])"), - ", nullptr }\n" + ", nullptr }, ", mayContainCapabilities, "\n" "};\n" "#endif // !CAPNP_LITE\n"); @@ -3042,7 +3068,9 @@ private: "#endif // !CAPNP_LITE\n" ) : kj::strTree(), "\n" - "#if CAPNP_VERSION != ", CAPNP_VERSION, "\n" + "#ifndef CAPNP_VERSION\n" + "#error \"CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?\"\n" + "#elif CAPNP_VERSION != ", CAPNP_VERSION, "\n" "#error \"Version mismatch between generated code and library headers. You must " "use the same version of the Cap'n Proto compiler and library.\"\n" "#endif\n" @@ -3152,6 +3180,8 @@ private: schemaLoader.load(node); } + schemaLoader.computeOptimizationHints(); + for (auto requestedFile: request.getRequestedFiles()) { auto schema = schemaLoader.get(requestedFile.getId()); auto fileText = makeFileText(schema, requestedFile); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/compiler.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/compiler.h index 36c5dca5f52..375bf59d43b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/compiler.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/compiler.h @@ -76,7 +76,7 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { explicit Compiler(AnnotationFlag annotationFlag = COMPILE_ANNOTATIONS); ~Compiler() noexcept(false); - KJ_DISALLOW_COPY(Compiler); + KJ_DISALLOW_COPY_AND_MOVE(Compiler); class CompiledType { // Represents a compiled type expression, from which you can traverse to nested types, apply @@ -197,11 +197,11 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { // dependencies. PARENTS = 1 << 1, - // Eagerly compile all lexical parents of the requested node. Only meaningful in conjuction + // Eagerly compile all lexical parents of the requested node. Only meaningful in conjunction // with NODE. CHILDREN = 1 << 2, - // Eagerly compile all of the node's lexically nested nodes. Only meaningful in conjuction + // Eagerly compile all of the node's lexically nested nodes. Only meaningful in conjunction // with NODE. DEPENDENCIES = NODE << 15, diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/error-reporter.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/error-reporter.h index e3bf6acf6ff..1fc66c52cdb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/error-reporter.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/error-reporter.h @@ -21,7 +21,7 @@ #pragma once -#include "../common.h" +#include #include #include #include diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/evolution-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/evolution-test.c++ index 48fe66bf534..964105a6129 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/evolution-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/evolution-test.c++ @@ -872,7 +872,7 @@ public: } kj::MainBuilder::Validity run() { - // https://github.com/sandstorm-io/capnproto/issues/344 describes an obscure bug in the layout + // https://github.com/capnproto/capnproto/issues/344 describes an obscure bug in the layout // algorithm, the fix for which breaks backwards-compatibility for any schema triggering the // bug. In order to avoid silently breaking protocols, we are temporarily throwing an exception // in cases where this bug would have occurred, so that people can decide what to do. diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.c++ index 65b20e9bf96..f433a1cf14f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.c++ @@ -79,7 +79,7 @@ static const uint16_t m_e75816b56529d464[] = {2, 1, 0}; static const uint16_t i_e75816b56529d464[] = {0, 1, 2}; const ::capnp::_::RawSchema s_e75816b56529d464 = { 0xe75816b56529d464, b_e75816b56529d464.words, 66, nullptr, m_e75816b56529d464, - 0, 3, i_e75816b56529d464, nullptr, nullptr, { &s_e75816b56529d464, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_e75816b56529d464, nullptr, nullptr, { &s_e75816b56529d464, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<66> b_991c7a3693d62cf2 = { @@ -156,7 +156,7 @@ static const uint16_t m_991c7a3693d62cf2[] = {2, 1, 0}; static const uint16_t i_991c7a3693d62cf2[] = {0, 1, 2}; const ::capnp::_::RawSchema s_991c7a3693d62cf2 = { 0x991c7a3693d62cf2, b_991c7a3693d62cf2.words, 66, nullptr, m_991c7a3693d62cf2, - 0, 3, i_991c7a3693d62cf2, nullptr, nullptr, { &s_991c7a3693d62cf2, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_991c7a3693d62cf2, nullptr, nullptr, { &s_991c7a3693d62cf2, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<66> b_90f2a60678fd2367 = { @@ -233,7 +233,7 @@ static const uint16_t m_90f2a60678fd2367[] = {2, 1, 0}; static const uint16_t i_90f2a60678fd2367[] = {0, 1, 2}; const ::capnp::_::RawSchema s_90f2a60678fd2367 = { 0x90f2a60678fd2367, b_90f2a60678fd2367.words, 66, nullptr, m_90f2a60678fd2367, - 0, 3, i_90f2a60678fd2367, nullptr, nullptr, { &s_90f2a60678fd2367, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_90f2a60678fd2367, nullptr, nullptr, { &s_90f2a60678fd2367, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<262> b_8e207d4dfe54d0de = { @@ -513,7 +513,7 @@ static const uint16_t m_8e207d4dfe54d0de[] = {13, 11, 10, 15, 9, 3, 14, 6, 12, 2 static const uint16_t i_8e207d4dfe54d0de[] = {0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 8, 9}; const ::capnp::_::RawSchema s_8e207d4dfe54d0de = { 0x8e207d4dfe54d0de, b_8e207d4dfe54d0de.words, 262, d_8e207d4dfe54d0de, m_8e207d4dfe54d0de, - 5, 16, i_8e207d4dfe54d0de, nullptr, nullptr, { &s_8e207d4dfe54d0de, nullptr, nullptr, 0, 0, nullptr } + 5, 16, i_8e207d4dfe54d0de, nullptr, nullptr, { &s_8e207d4dfe54d0de, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_c90246b71adedbaa = { @@ -593,7 +593,7 @@ static const uint16_t m_c90246b71adedbaa[] = {1, 0, 2}; static const uint16_t i_c90246b71adedbaa[] = {0, 1, 2}; const ::capnp::_::RawSchema s_c90246b71adedbaa = { 0xc90246b71adedbaa, b_c90246b71adedbaa.words, 65, d_c90246b71adedbaa, m_c90246b71adedbaa, - 2, 3, i_c90246b71adedbaa, nullptr, nullptr, { &s_c90246b71adedbaa, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_c90246b71adedbaa, nullptr, nullptr, { &s_c90246b71adedbaa, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<55> b_aee8397040b0df7a = { @@ -663,7 +663,7 @@ static const uint16_t m_aee8397040b0df7a[] = {0, 1}; static const uint16_t i_aee8397040b0df7a[] = {0, 1}; const ::capnp::_::RawSchema s_aee8397040b0df7a = { 0xaee8397040b0df7a, b_aee8397040b0df7a.words, 55, d_aee8397040b0df7a, m_aee8397040b0df7a, - 2, 2, i_aee8397040b0df7a, nullptr, nullptr, { &s_aee8397040b0df7a, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_aee8397040b0df7a, nullptr, nullptr, { &s_aee8397040b0df7a, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_aa28e1400d793359 = { @@ -727,7 +727,7 @@ static const uint16_t m_aa28e1400d793359[] = {1, 0}; static const uint16_t i_aa28e1400d793359[] = {0, 1}; const ::capnp::_::RawSchema s_aa28e1400d793359 = { 0xaa28e1400d793359, b_aa28e1400d793359.words, 49, d_aa28e1400d793359, m_aa28e1400d793359, - 2, 2, i_aa28e1400d793359, nullptr, nullptr, { &s_aa28e1400d793359, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_aa28e1400d793359, nullptr, nullptr, { &s_aa28e1400d793359, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<677> b_96efe787c17e83bb = { @@ -1429,7 +1429,7 @@ static const uint16_t m_96efe787c17e83bb[] = {18, 3, 40, 37, 39, 22, 41, 34, 31, static const uint16_t i_96efe787c17e83bb[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 0, 1, 2, 3, 4, 5, 6, 38}; const ::capnp::_::RawSchema s_96efe787c17e83bb = { 0x96efe787c17e83bb, b_96efe787c17e83bb.words, 677, d_96efe787c17e83bb, m_96efe787c17e83bb, - 12, 42, i_96efe787c17e83bb, nullptr, nullptr, { &s_96efe787c17e83bb, nullptr, nullptr, 0, 0, nullptr } + 12, 42, i_96efe787c17e83bb, nullptr, nullptr, { &s_96efe787c17e83bb, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<67> b_d5e71144af1ce175 = { @@ -1507,7 +1507,7 @@ static const uint16_t m_d5e71144af1ce175[] = {2, 0, 1}; static const uint16_t i_d5e71144af1ce175[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d5e71144af1ce175 = { 0xd5e71144af1ce175, b_d5e71144af1ce175.words, 67, nullptr, m_d5e71144af1ce175, - 0, 3, i_d5e71144af1ce175, nullptr, nullptr, { &s_d5e71144af1ce175, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d5e71144af1ce175, nullptr, nullptr, { &s_d5e71144af1ce175, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<45> b_d00489d473826290 = { @@ -1567,7 +1567,7 @@ static const uint16_t m_d00489d473826290[] = {0, 1}; static const uint16_t i_d00489d473826290[] = {0, 1}; const ::capnp::_::RawSchema s_d00489d473826290 = { 0xd00489d473826290, b_d00489d473826290.words, 45, d_d00489d473826290, m_d00489d473826290, - 2, 2, i_d00489d473826290, nullptr, nullptr, { &s_d00489d473826290, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_d00489d473826290, nullptr, nullptr, { &s_d00489d473826290, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<53> b_fb5aeed95cdf6af9 = { @@ -1635,7 +1635,7 @@ static const uint16_t m_fb5aeed95cdf6af9[] = {1, 0}; static const uint16_t i_fb5aeed95cdf6af9[] = {0, 1}; const ::capnp::_::RawSchema s_fb5aeed95cdf6af9 = { 0xfb5aeed95cdf6af9, b_fb5aeed95cdf6af9.words, 53, d_fb5aeed95cdf6af9, m_fb5aeed95cdf6af9, - 2, 2, i_fb5aeed95cdf6af9, nullptr, nullptr, { &s_fb5aeed95cdf6af9, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_fb5aeed95cdf6af9, nullptr, nullptr, { &s_fb5aeed95cdf6af9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<28> b_94099c3f9eb32d6b = { @@ -1672,7 +1672,7 @@ static const ::capnp::_::AlignedData<28> b_94099c3f9eb32d6b = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_94099c3f9eb32d6b = { 0x94099c3f9eb32d6b, b_94099c3f9eb32d6b.words, 28, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_94099c3f9eb32d6b, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_94099c3f9eb32d6b, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<102> b_b3f66e7a79d81bcd = { @@ -1789,7 +1789,7 @@ static const uint16_t m_b3f66e7a79d81bcd[] = {3, 0, 2, 4, 1}; static const uint16_t i_b3f66e7a79d81bcd[] = {0, 1, 4, 2, 3}; const ::capnp::_::RawSchema s_b3f66e7a79d81bcd = { 0xb3f66e7a79d81bcd, b_b3f66e7a79d81bcd.words, 102, d_b3f66e7a79d81bcd, m_b3f66e7a79d81bcd, - 2, 5, i_b3f66e7a79d81bcd, nullptr, nullptr, { &s_b3f66e7a79d81bcd, nullptr, nullptr, 0, 0, nullptr } + 2, 5, i_b3f66e7a79d81bcd, nullptr, nullptr, { &s_b3f66e7a79d81bcd, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<110> b_fffe08a9a697d2a5 = { @@ -1916,7 +1916,7 @@ static const uint16_t m_fffe08a9a697d2a5[] = {2, 3, 5, 0, 4, 1}; static const uint16_t i_fffe08a9a697d2a5[] = {0, 1, 2, 3, 4, 5}; const ::capnp::_::RawSchema s_fffe08a9a697d2a5 = { 0xfffe08a9a697d2a5, b_fffe08a9a697d2a5.words, 110, d_fffe08a9a697d2a5, m_fffe08a9a697d2a5, - 4, 6, i_fffe08a9a697d2a5, nullptr, nullptr, { &s_fffe08a9a697d2a5, nullptr, nullptr, 0, 0, nullptr } + 4, 6, i_fffe08a9a697d2a5, nullptr, nullptr, { &s_fffe08a9a697d2a5, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_e5104515fd88ea47 = { @@ -1982,7 +1982,7 @@ static const uint16_t m_e5104515fd88ea47[] = {0, 1}; static const uint16_t i_e5104515fd88ea47[] = {0, 1}; const ::capnp::_::RawSchema s_e5104515fd88ea47 = { 0xe5104515fd88ea47, b_e5104515fd88ea47.words, 51, d_e5104515fd88ea47, m_e5104515fd88ea47, - 2, 2, i_e5104515fd88ea47, nullptr, nullptr, { &s_e5104515fd88ea47, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_e5104515fd88ea47, nullptr, nullptr, { &s_e5104515fd88ea47, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_89f0c973c103ae96 = { @@ -2062,7 +2062,7 @@ static const uint16_t m_89f0c973c103ae96[] = {2, 1, 0}; static const uint16_t i_89f0c973c103ae96[] = {0, 1, 2}; const ::capnp::_::RawSchema s_89f0c973c103ae96 = { 0x89f0c973c103ae96, b_89f0c973c103ae96.words, 65, d_89f0c973c103ae96, m_89f0c973c103ae96, - 2, 3, i_89f0c973c103ae96, nullptr, nullptr, { &s_89f0c973c103ae96, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_89f0c973c103ae96, nullptr, nullptr, { &s_89f0c973c103ae96, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_e93164a80bfe2ccf = { @@ -2111,7 +2111,7 @@ static const uint16_t m_e93164a80bfe2ccf[] = {0}; static const uint16_t i_e93164a80bfe2ccf[] = {0}; const ::capnp::_::RawSchema s_e93164a80bfe2ccf = { 0xe93164a80bfe2ccf, b_e93164a80bfe2ccf.words, 34, d_e93164a80bfe2ccf, m_e93164a80bfe2ccf, - 2, 1, i_e93164a80bfe2ccf, nullptr, nullptr, { &s_e93164a80bfe2ccf, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_e93164a80bfe2ccf, nullptr, nullptr, { &s_e93164a80bfe2ccf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_b348322a8dcf0d0c = { @@ -2175,7 +2175,7 @@ static const uint16_t m_b348322a8dcf0d0c[] = {0, 1}; static const uint16_t i_b348322a8dcf0d0c[] = {0, 1}; const ::capnp::_::RawSchema s_b348322a8dcf0d0c = { 0xb348322a8dcf0d0c, b_b348322a8dcf0d0c.words, 49, d_b348322a8dcf0d0c, m_b348322a8dcf0d0c, - 2, 2, i_b348322a8dcf0d0c, nullptr, nullptr, { &s_b348322a8dcf0d0c, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_b348322a8dcf0d0c, nullptr, nullptr, { &s_b348322a8dcf0d0c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<43> b_8f2622208fb358c8 = { @@ -2234,7 +2234,7 @@ static const uint16_t m_8f2622208fb358c8[] = {1, 0}; static const uint16_t i_8f2622208fb358c8[] = {0, 1}; const ::capnp::_::RawSchema s_8f2622208fb358c8 = { 0x8f2622208fb358c8, b_8f2622208fb358c8.words, 43, d_8f2622208fb358c8, m_8f2622208fb358c8, - 3, 2, i_8f2622208fb358c8, nullptr, nullptr, { &s_8f2622208fb358c8, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_8f2622208fb358c8, nullptr, nullptr, { &s_8f2622208fb358c8, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_d0d1a21de617951f = { @@ -2300,7 +2300,7 @@ static const uint16_t m_d0d1a21de617951f[] = {0, 1}; static const uint16_t i_d0d1a21de617951f[] = {0, 1}; const ::capnp::_::RawSchema s_d0d1a21de617951f = { 0xd0d1a21de617951f, b_d0d1a21de617951f.words, 51, d_d0d1a21de617951f, m_d0d1a21de617951f, - 2, 2, i_d0d1a21de617951f, nullptr, nullptr, { &s_d0d1a21de617951f, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_d0d1a21de617951f, nullptr, nullptr, { &s_d0d1a21de617951f, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<40> b_992a90eaf30235d3 = { @@ -2355,7 +2355,7 @@ static const uint16_t m_992a90eaf30235d3[] = {0}; static const uint16_t i_992a90eaf30235d3[] = {0}; const ::capnp::_::RawSchema s_992a90eaf30235d3 = { 0x992a90eaf30235d3, b_992a90eaf30235d3.words, 40, d_992a90eaf30235d3, m_992a90eaf30235d3, - 2, 1, i_992a90eaf30235d3, nullptr, nullptr, { &s_992a90eaf30235d3, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_992a90eaf30235d3, nullptr, nullptr, { &s_992a90eaf30235d3, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<42> b_eb971847d617c0b9 = { @@ -2413,7 +2413,7 @@ static const uint16_t m_eb971847d617c0b9[] = {0, 1}; static const uint16_t i_eb971847d617c0b9[] = {0, 1}; const ::capnp::_::RawSchema s_eb971847d617c0b9 = { 0xeb971847d617c0b9, b_eb971847d617c0b9.words, 42, d_eb971847d617c0b9, m_eb971847d617c0b9, - 3, 2, i_eb971847d617c0b9, nullptr, nullptr, { &s_eb971847d617c0b9, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_eb971847d617c0b9, nullptr, nullptr, { &s_eb971847d617c0b9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_c6238c7d62d65173 = { @@ -2479,7 +2479,7 @@ static const uint16_t m_c6238c7d62d65173[] = {1, 0}; static const uint16_t i_c6238c7d62d65173[] = {0, 1}; const ::capnp::_::RawSchema s_c6238c7d62d65173 = { 0xc6238c7d62d65173, b_c6238c7d62d65173.words, 51, d_c6238c7d62d65173, m_c6238c7d62d65173, - 2, 2, i_c6238c7d62d65173, nullptr, nullptr, { &s_c6238c7d62d65173, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_c6238c7d62d65173, nullptr, nullptr, { &s_c6238c7d62d65173, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<230> b_9cb9e86e3198037f = { @@ -2724,7 +2724,7 @@ static const uint16_t m_9cb9e86e3198037f[] = {12, 2, 3, 4, 6, 1, 8, 9, 10, 11, 5 static const uint16_t i_9cb9e86e3198037f[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; const ::capnp::_::RawSchema s_9cb9e86e3198037f = { 0x9cb9e86e3198037f, b_9cb9e86e3198037f.words, 230, d_9cb9e86e3198037f, m_9cb9e86e3198037f, - 2, 13, i_9cb9e86e3198037f, nullptr, nullptr, { &s_9cb9e86e3198037f, nullptr, nullptr, 0, 0, nullptr } + 2, 13, i_9cb9e86e3198037f, nullptr, nullptr, { &s_9cb9e86e3198037f, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_84e4f3f5a807605c = { @@ -2772,7 +2772,7 @@ static const uint16_t m_84e4f3f5a807605c[] = {0}; static const uint16_t i_84e4f3f5a807605c[] = {0}; const ::capnp::_::RawSchema s_84e4f3f5a807605c = { 0x84e4f3f5a807605c, b_84e4f3f5a807605c.words, 34, d_84e4f3f5a807605c, m_84e4f3f5a807605c, - 1, 1, i_84e4f3f5a807605c, nullptr, nullptr, { &s_84e4f3f5a807605c, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_84e4f3f5a807605c, nullptr, nullptr, { &s_84e4f3f5a807605c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -2784,195 +2784,291 @@ namespace capnp { namespace compiler { // LocatedText +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedText::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedText::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedText::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedText::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LocatedInteger +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedInteger::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedInteger::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedInteger::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedInteger::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LocatedFloat +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedFloat::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedFloat::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedFloat::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedFloat::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Param +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Param::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Param::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Param::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Param::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Application +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Application::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Application::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Application::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Application::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Member +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Member::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Member::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Member::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Member::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::BrandParameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::BrandParameter::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::BrandParameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::BrandParameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::BrandParameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::AnnotationApplication +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::AnnotationApplication::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::AnnotationApplication::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::AnnotationApplication::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::AnnotationApplication::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::AnnotationApplication::Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::AnnotationApplication::Value::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::AnnotationApplication::Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::AnnotationApplication::Value::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::AnnotationApplication::Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::ParamList +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::ParamList::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::ParamList::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::ParamList::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::ParamList::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Param +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Param::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Param::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Param::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Param::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Param::DefaultValue +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Param::DefaultValue::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Param::DefaultValue::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Param::DefaultValue::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Param::DefaultValue::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Id +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Id::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Id::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Id::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Id::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Using +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Using::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Using::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Using::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Using::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Const +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Const::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Const::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Const::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Const::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Field::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Field::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Field::DefaultValue +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Field::DefaultValue::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Field::DefaultValue::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Field::DefaultValue::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Field::DefaultValue::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Method +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Method::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Method::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Method::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Method::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Method::Results +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Method::Results::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Method::Results::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Method::Results::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Method::Results::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ParsedFile +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ParsedFile::_capnpPrivate::dataWordSize; constexpr uint16_t ParsedFile::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ParsedFile::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ParsedFile::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.h index 34825b31c2b..fcb7f2d12aa 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/grammar.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.c++ index 02a1dab5725..22bd04d66df 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.c++ @@ -186,6 +186,20 @@ Lexer::Lexer(Orphanage orphanageParam, ErrorReporter& errorReporter) initTok(t, loc).setStringLiteral(text); return t; }), + p::transformWithLocation( + sequence(p::exactChar<'`'>(), p::many(p::anyOfChars("\r\n").invert())), + [this](Location loc, kj::Array text) -> Orphan { + // Backtick-quoted line. Note that we assume either `\r` or `\n` is a valid line + // ending (to cover all known line ending formats) but we replace the line ending + // with `\n`. This way, changing the line endings of your source code doesn't affect + // the compiled code. + auto t = orphanage.newOrphan(); + // Append '\n' to the text. + auto out = initTok(t, loc).initStringLiteral(text.size() + 1); + memcpy(out.begin(), text.begin(), text.size()); + out[out.size() - 1] = '\n'; + return t; + }), p::transformWithLocation(p::doubleQuotedHexBinary, [this](Location loc, kj::Array data) -> Orphan { auto t = orphanage.newOrphan(); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.c++ index 7cad320e270..316255fc91c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.c++ @@ -211,7 +211,7 @@ static const uint16_t m_91cc55cd57de5419[] = {9, 6, 8, 3, 0, 2, 4, 5, 7, 1}; static const uint16_t i_91cc55cd57de5419[] = {0, 1, 2, 3, 4, 5, 6, 9, 7, 8}; const ::capnp::_::RawSchema s_91cc55cd57de5419 = { 0x91cc55cd57de5419, b_91cc55cd57de5419.words, 195, d_91cc55cd57de5419, m_91cc55cd57de5419, - 1, 10, i_91cc55cd57de5419, nullptr, nullptr, { &s_91cc55cd57de5419, nullptr, nullptr, 0, 0, nullptr } + 1, 10, i_91cc55cd57de5419, nullptr, nullptr, { &s_91cc55cd57de5419, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<119> b_c6725e678d60fa37 = { @@ -345,7 +345,7 @@ static const uint16_t m_c6725e678d60fa37[] = {2, 3, 5, 1, 4, 0}; static const uint16_t i_c6725e678d60fa37[] = {1, 2, 0, 3, 4, 5}; const ::capnp::_::RawSchema s_c6725e678d60fa37 = { 0xc6725e678d60fa37, b_c6725e678d60fa37.words, 119, d_c6725e678d60fa37, m_c6725e678d60fa37, - 2, 6, i_c6725e678d60fa37, nullptr, nullptr, { &s_c6725e678d60fa37, nullptr, nullptr, 0, 0, nullptr } + 2, 6, i_c6725e678d60fa37, nullptr, nullptr, { &s_c6725e678d60fa37, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<38> b_9e69a92512b19d18 = { @@ -397,7 +397,7 @@ static const uint16_t m_9e69a92512b19d18[] = {0}; static const uint16_t i_9e69a92512b19d18[] = {0}; const ::capnp::_::RawSchema s_9e69a92512b19d18 = { 0x9e69a92512b19d18, b_9e69a92512b19d18.words, 38, d_9e69a92512b19d18, m_9e69a92512b19d18, - 1, 1, i_9e69a92512b19d18, nullptr, nullptr, { &s_9e69a92512b19d18, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_9e69a92512b19d18, nullptr, nullptr, { &s_9e69a92512b19d18, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<40> b_a11f97b9d6c73dd4 = { @@ -451,7 +451,7 @@ static const uint16_t m_a11f97b9d6c73dd4[] = {0}; static const uint16_t i_a11f97b9d6c73dd4[] = {0}; const ::capnp::_::RawSchema s_a11f97b9d6c73dd4 = { 0xa11f97b9d6c73dd4, b_a11f97b9d6c73dd4.words, 40, d_a11f97b9d6c73dd4, m_a11f97b9d6c73dd4, - 1, 1, i_a11f97b9d6c73dd4, nullptr, nullptr, { &s_a11f97b9d6c73dd4, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_a11f97b9d6c73dd4, nullptr, nullptr, { &s_a11f97b9d6c73dd4, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -463,35 +463,51 @@ namespace capnp { namespace compiler { // Token +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Token::_capnpPrivate::dataWordSize; constexpr uint16_t Token::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Token::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Token::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Statement +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Statement::_capnpPrivate::dataWordSize; constexpr uint16_t Statement::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Statement::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Statement::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LexedTokens +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LexedTokens::_capnpPrivate::dataWordSize; constexpr uint16_t LexedTokens::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LexedTokens::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LexedTokens::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LexedStatements +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LexedStatements::_capnpPrivate::dataWordSize; constexpr uint16_t LexedStatements::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LexedStatements::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LexedStatements::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.h index a93065a2c29..83930163898 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/lexer.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.c++ index 803d39238ed..d72b0fce0c7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.c++ @@ -137,10 +137,14 @@ public: kj::Maybe> readEmbedFromSearchPath(kj::PathPtr path); GlobalErrorReporter& getErrorReporter() { return errorReporter; } + void setFileIdsRequired(bool value) { fileIdsRequired = value; } + bool areFileIdsRequired() { return fileIdsRequired; } + private: GlobalErrorReporter& errorReporter; kj::Vector searchPath; std::unordered_map, FileKeyHash> modules; + bool fileIdsRequired = true; }; class ModuleLoader::ModuleImpl final: public Module { @@ -171,7 +175,7 @@ public: lex(content, statements, *this); auto parsed = orphanage.newOrphan(); - parseFile(statements.getStatements(), parsed.get(), *this); + parseFile(statements.getStatements(), parsed.get(), *this, loader.areFileIdsRequired()); return parsed; } @@ -282,5 +286,9 @@ kj::Maybe ModuleLoader::loadModule(const kj::ReadableDirectory& dir, kj return impl->loadModule(dir, path); } +void ModuleLoader::setFileIdsRequired(bool value) { + return impl->setFileIdsRequired(value); +} + } // namespace compiler } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.h index 5e51ac488b5..da0d6daf298 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/module-loader.h @@ -38,7 +38,7 @@ class ModuleLoader { explicit ModuleLoader(GlobalErrorReporter& errorReporter); // Create a ModuleLoader that reports error messages to the given reporter. - KJ_DISALLOW_COPY(ModuleLoader); + KJ_DISALLOW_COPY_AND_MOVE(ModuleLoader); ~ModuleLoader() noexcept(false); @@ -49,6 +49,10 @@ class ModuleLoader { // Tries to load a module with the given path inside the given directory. Returns nullptr if the // file doesn't exist. + void setFileIdsRequired(bool value); + // Same as SchemaParser::setFileIdsRequired(). If set false, files will not be required to have + // a top-level file ID; if missing a random one will be assigned. + private: class Impl; kj::Own impl; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.c++ index d7227386fd1..fd2577eb007 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.c++ @@ -110,7 +110,7 @@ public: // from the given offset. The idea is that you just allocated an lgSize-sized field from // an limitLgSize-sized space, such as a newly-added word on the end of the data segment. - KJ_DREQUIRE(limitLgSize <= kj::size(holes)); + KJ_ASSUME(limitLgSize <= kj::size(holes)); while (lgSize < limitLgSize) { KJ_DREQUIRE(holes[lgSize] == 0); @@ -217,7 +217,7 @@ public: } Top() = default; - KJ_DISALLOW_COPY(Top); + KJ_DISALLOW_COPY_AND_MOVE(Top); }; struct Union { @@ -245,7 +245,7 @@ public: kj::Vector pointerLocations; inline Union(StructOrGroup& parent): parent(parent) {} - KJ_DISALLOW_COPY(Union); + KJ_DISALLOW_COPY_AND_MOVE(Union); uint addNewDataLocation(uint lgSize) { // Add a whole new data location to the union with the given size. @@ -433,7 +433,7 @@ public: // exception to alert developers of the problem. // // TODO(cleanup): Once sufficient time has elapsed, remove this assert. - KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/sandstorm-io/capnproto/issues/344"); + KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/capnproto/capnproto/issues/344"); } lgSizeUsed = desiredUsage; return true; @@ -452,7 +452,7 @@ public: bool hasMembers = false; inline Group(Union& parent): parent(parent) {} - KJ_DISALLOW_COPY(Group); + KJ_DISALLOW_COPY_AND_MOVE(Group); void addMember() { if (!hasMembers) { @@ -559,7 +559,7 @@ public: bool result = usage.tryExpand( *this, location, oldLgSize, localOldOffset, expansionFactor); if (mustFail && result) { - KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/sandstorm-io/capnproto/issues/344"); + KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/capnproto/capnproto/issues/344"); } return result; } @@ -955,7 +955,7 @@ public: explicit StructTranslator(NodeTranslator& translator, ImplicitParams implicitMethodParams) : translator(translator), errorReporter(translator.errorReporter), implicitMethodParams(implicitMethodParams) {} - KJ_DISALLOW_COPY(StructTranslator); + KJ_DISALLOW_COPY_AND_MOVE(StructTranslator); void translate(Void decl, List::Reader members, schema::Node::Builder builder, schema::Node::SourceInfo::Builder sourceInfo) { @@ -1370,7 +1370,7 @@ private: MemberInfo& member = *entry.second; // Make sure the exceptions added relating to - // https://github.com/sandstorm-io/capnproto/issues/344 identify the affected field. + // https://github.com/capnproto/capnproto/issues/344 identify the affected field. KJ_CONTEXT(member.name); if (member.declId.isOrdinal()) { @@ -1841,31 +1841,52 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader Orphan result = compileValueInner(src, type); + if (result.getType() == DynamicValue::UNKNOWN) { + // Error already reported. + return nullptr; + } else if (matchesType(src, type, result)) { + return kj::mv(result); + } else { + // If the expected type is a struct, we try matching its first field. + if (type.isStruct()) { + auto structType = type.asStruct(); + auto fields = structType.getFields(); + if (fields.size() > 0) { + auto field = fields[0]; + if (matchesType(src, field.getType(), result)) { + // Success. Wrap in a struct. + auto outer = orphanage.newOrphan(type.asStruct()); + outer.get().adopt(field, kj::mv(result)); + return Orphan(kj::mv(outer)); + } + } + } + + // That didn't work, so this is just a type mismatch. + errorReporter.addErrorOn(src, kj::str("Type mismatch; expected ", makeTypeName(type), ".")); + return nullptr; + } +} + +bool ValueTranslator::matchesType(Expression::Reader src, Type type, Orphan& result) { // compileValueInner() evaluated `src` and only used `type` as a hint in interpreting `src` if // `src`'s type wasn't already obvious. So, now we need to check that the resulting value // actually matches `type`. switch (result.getType()) { case DynamicValue::UNKNOWN: - // Error already reported. - return nullptr; + KJ_UNREACHABLE; case DynamicValue::VOID: - if (type.isVoid()) { - return kj::mv(result); - } - break; + return type.isVoid(); case DynamicValue::BOOL: - if (type.isBool()) { - return kj::mv(result); - } - break; + return type.isBool(); case DynamicValue::INT: { int64_t value = result.getReader().as(); if (value < 0) { - int64_t minValue = 1; + int64_t minValue; switch (type.which()) { case schema::Type::INT8: minValue = (int8_t)kj::minValue; break; case schema::Type::INT16: minValue = (int16_t)kj::minValue; break; @@ -1882,21 +1903,20 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader minValue = (int64_t)kj::minValue; break; - default: break; + default: return false; } - if (minValue == 1) break; if (value < minValue) { errorReporter.addErrorOn(src, "Integer value out of range."); result = minValue; } - return kj::mv(result); + return true; } } KJ_FALLTHROUGH; // value is positive, so we can just go on to the uint case below. case DynamicValue::UINT: { - uint64_t maxValue = 0; + uint64_t maxValue; switch (type.which()) { case schema::Type::INT8: maxValue = (int8_t)kj::maxValue; break; case schema::Type::INT16: maxValue = (int16_t)kj::maxValue; break; @@ -1913,76 +1933,62 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader maxValue = (uint64_t)kj::maxValue; break; - default: break; + default: return false; } - if (maxValue == 0) break; if (result.getReader().as() > maxValue) { errorReporter.addErrorOn(src, "Integer value out of range."); result = maxValue; } - return kj::mv(result); + return true; } case DynamicValue::FLOAT: - if (type.isFloat32() || type.isFloat64()) { - return kj::mv(result); - } - break; + return type.isFloat32() || type.isFloat64(); case DynamicValue::TEXT: - if (type.isText()) { - return kj::mv(result); - } - break; + return type.isText(); case DynamicValue::DATA: - if (type.isData()) { - return kj::mv(result); - } - break; + return type.isData(); case DynamicValue::LIST: if (type.isList()) { - if (result.getReader().as().getSchema() == type.asList()) { - return kj::mv(result); - } + return result.getReader().as().getSchema() == type.asList(); } else if (type.isAnyPointer()) { switch (type.whichAnyPointerKind()) { case schema::Type::AnyPointer::Unconstrained::ANY_KIND: case schema::Type::AnyPointer::Unconstrained::LIST: - return kj::mv(result); + return true; case schema::Type::AnyPointer::Unconstrained::STRUCT: case schema::Type::AnyPointer::Unconstrained::CAPABILITY: - break; + return false; } + KJ_UNREACHABLE; + } else { + return false; } - break; case DynamicValue::ENUM: - if (type.isEnum()) { - if (result.getReader().as().getSchema() == type.asEnum()) { - return kj::mv(result); - } - } - break; + return type.isEnum() && + result.getReader().as().getSchema() == type.asEnum(); case DynamicValue::STRUCT: if (type.isStruct()) { - if (result.getReader().as().getSchema() == type.asStruct()) { - return kj::mv(result); - } + return result.getReader().as().getSchema() == type.asStruct(); } else if (type.isAnyPointer()) { switch (type.whichAnyPointerKind()) { case schema::Type::AnyPointer::Unconstrained::ANY_KIND: case schema::Type::AnyPointer::Unconstrained::STRUCT: - return kj::mv(result); + return true; case schema::Type::AnyPointer::Unconstrained::LIST: case schema::Type::AnyPointer::Unconstrained::CAPABILITY: - break; + return false; } + KJ_UNREACHABLE; + } else { + return false; } - break; case DynamicValue::CAPABILITY: KJ_FAIL_ASSERT("Interfaces can't have literal values."); @@ -1991,8 +1997,7 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader KJ_FAIL_ASSERT("AnyPointers can't have literal values."); } - errorReporter.addErrorOn(src, kj::str("Type mismatch; expected ", makeTypeName(type), ".")); - return nullptr; + KJ_UNREACHABLE; } Orphan ValueTranslator::compileValueInner(Expression::Reader src, Type type) { @@ -2183,9 +2188,26 @@ void ValueTranslator::fillStructValue(DynamicStruct::Builder builder, break; case schema::Field::GROUP: + auto groupBuilder = builder.init(*field).as(); if (value.isTuple()) { - fillStructValue(builder.init(*field).as(), value.getTuple()); + fillStructValue(groupBuilder, value.getTuple()); } else { + auto groupFields = groupBuilder.getSchema().getFields(); + if (groupFields.size() > 0) { + auto groupField = groupFields[0]; + + // Call compileValueInner() using the group's type as `type`. Since we already + // established `value` is not a tuple, this will only return a valid result if + // the value has unambiguous type. + auto result = compileValueInner(value, field->getType()); + + // Does it match the first field? + if (matchesType(value, groupField.getType(), result)) { + groupBuilder.adopt(groupField, kj::mv(result)); + break; + } + } + errorReporter.addErrorOn(value, "Type mismatch; expected group."); } break; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.h index 6365fa005ea..7562b086373 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/node-translator.h @@ -169,7 +169,7 @@ class NodeTranslator { void compileBootstrapValue( Expression::Reader source, schema::Type::Reader type, schema::Value::Builder target, kj::Maybe typeScope = nullptr); - // Calls compileValue() if this value should be interpreted at bootstrap time. Otheriwse, + // Calls compileValue() if this value should be interpreted at bootstrap time. Otherwise, // adds the value to `unfinishedValues` for later evaluation. // // If `type` comes from some other node, `typeScope` is the schema for that node. Otherwise the @@ -215,7 +215,8 @@ class ValueTranslator { Orphanage orphanage; Orphan compileValueInner(Expression::Reader src, Type type); - // Helper for compileValue(). + bool matchesType(Expression::Reader src, Type type, Orphan& result); + // Helpers for compileValue(). kj::String makeNodeName(Schema node); kj::String makeTypeName(Type type); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.c++ index eadb279577f..91a8c5cf4c0 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.c++ @@ -59,6 +59,7 @@ uint64_t generateRandomId() { #else int fd; KJ_SYSCALL(fd = open("/dev/urandom", O_RDONLY)); + KJ_DEFER(close(fd)); ssize_t n; KJ_SYSCALL(n = read(fd, &result, sizeof(result)), "/dev/urandom"); @@ -69,7 +70,7 @@ uint64_t generateRandomId() { } void parseFile(List::Reader statements, ParsedFile::Builder result, - ErrorReporter& errorReporter) { + ErrorReporter& errorReporter, bool requiresId) { CapnpParser parser(Orphanage::getForMessageContaining(result), errorReporter); kj::Vector> decls(statements.size()); @@ -110,7 +111,7 @@ void parseFile(List::Reader statements, ParsedFile::Builder result, // Don't report missing ID if there was a parse error, because quite often the parse error // prevents us from parsing the ID even though it is actually there. - if (!errorReporter.hadErrors()) { + if (requiresId && !errorReporter.hadErrors()) { errorReporter.addError(0, 0, kj::str("File does not declare an ID. I've generated one for you. Add this line to " "your file: @0x", kj::hex(id), ";")); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.h index f881a0d9345..7c798b2742c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/parser.h @@ -33,7 +33,7 @@ namespace capnp { namespace compiler { void parseFile(List::Reader statements, ParsedFile::Builder result, - ErrorReporter& errorReporter); + ErrorReporter& errorReporter, bool requiresId); // Parse a list of statements to build a ParsedFile. // // If any errors are reported, then the output is not usable. However, it may be passed on through @@ -64,7 +64,7 @@ class CapnpParser { ~CapnpParser() noexcept(false); - KJ_DISALLOW_COPY(CapnpParser); + KJ_DISALLOW_COPY_AND_MOVE(CapnpParser); using ParserInput = kj::parse::IteratorInput::Reader::Iterator>; struct DeclParserResult; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/type-id.h b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/type-id.h index a450524551f..5968a1762d9 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/type-id.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/compiler/type-id.h @@ -25,6 +25,8 @@ #include #include +CAPNP_BEGIN_HEADER + namespace capnp { namespace compiler { @@ -40,3 +42,5 @@ uint64_t generateMethodParamsId(uint64_t parentId, uint16_t methodOrdinal, bool } // namespace compiler } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic-capability.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic-capability.c++ index 5a5cb3570ba..81a4ed3540f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic-capability.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic-capability.c++ @@ -40,8 +40,11 @@ Request DynamicCapability::Client::newRequest( auto paramType = method.getParamType(); auto resultType = method.getResultType(); + CallHints hints; + hints.noPromisePipelining = !resultType.mayContainCapabilities(); + auto typeless = hook->newCall( - methodInterface.getProto().getId(), method.getIndex(), sizeHint); + methodInterface.getProto().getId(), method.getIndex(), sizeHint, hints); return Request( typeless.getAs(paramType), kj::mv(typeless.hook), resultType); @@ -63,7 +66,8 @@ Capability::Server::DispatchCallResult DynamicCapability::Server::dispatchCall( return { call(method, CallContext(*context.hook, method.getParamType(), resultType)), - resultType.isStreamResult() + resultType.isStreamResult(), + options.allowCancellation }; } else { return internalUnimplemented( diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.c++ index 5983db47b1c..59e5d1268fb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.c++ @@ -180,7 +180,7 @@ DynamicValue::Reader DynamicStruct::Reader::get(StructSchema::Field field) const case schema::Field::SLOT: { auto slot = proto.getSlot(); - // Note that the default value might be "anyPointer" even if the type is some poniter type + // Note that the default value might be "anyPointer" even if the type is some pointer type // *other than* anyPointer. This happens with generics -- the field is actually a generic // parameter that has been bound, but the default value was of course compiled without any // binding available. @@ -272,7 +272,7 @@ DynamicValue::Builder DynamicStruct::Builder::get(StructSchema::Field field) { case schema::Field::SLOT: { auto slot = proto.getSlot(); - // Note that the default value might be "anyPointer" even if the type is some poniter type + // Note that the default value might be "anyPointer" even if the type is some pointer type // *other than* anyPointer. This happens with generics -- the field is actually a generic // parameter that has been bound, but the default value was of course compiled without any // binding available. @@ -1573,12 +1573,12 @@ DynamicValue::Builder::Builder(Builder& other) { // Unfortunately canMemcpy() doesn't work on these types due to the use of // DisallowConstCopy, but __has_trivial_destructor should detect if any of these types // become non-trivial. - static_assert(__has_trivial_destructor(Text::Builder) && - __has_trivial_destructor(Data::Builder) && - __has_trivial_destructor(DynamicList::Builder) && - __has_trivial_destructor(DynamicEnum) && - __has_trivial_destructor(DynamicStruct::Builder) && - __has_trivial_destructor(AnyPointer::Builder), + static_assert(KJ_HAS_TRIVIAL_DESTRUCTOR(Text::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(Data::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicList::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicEnum) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicStruct::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(AnyPointer::Builder), "Assumptions here don't hold."); break; @@ -1607,12 +1607,12 @@ DynamicValue::Builder::Builder(Builder&& other) noexcept { // Unfortunately __has_trivial_copy doesn't work on these types due to the use of // DisallowConstCopy, but __has_trivial_destructor should detect if any of these types // become non-trivial. - static_assert(__has_trivial_destructor(Text::Builder) && - __has_trivial_destructor(Data::Builder) && - __has_trivial_destructor(DynamicList::Builder) && - __has_trivial_destructor(DynamicEnum) && - __has_trivial_destructor(DynamicStruct::Builder) && - __has_trivial_destructor(AnyPointer::Builder), + static_assert(KJ_HAS_TRIVIAL_DESTRUCTOR(Text::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(Data::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicList::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicEnum) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicStruct::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(AnyPointer::Builder), "Assumptions here don't hold."); break; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.h b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.h index 3fd6cf2d76f..8aab1f7ad98 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/dynamic.h @@ -524,7 +524,16 @@ class DynamicCapability::Server: public Capability::Server { public: typedef DynamicCapability Serves; + struct Options { + bool allowCancellation = false; + // See the `allowCancellation` annotation defined in `c++.capnp`. + // + // This option applies to all calls made to this server object. The annotation in the schema + // is NOT used for dynamic servers. + }; + Server(InterfaceSchema schema): schema(schema) {} + Server(InterfaceSchema schema, Options options): schema(schema), options(options) {} virtual kj::Promise call(InterfaceSchema::Method method, CallContext context) = 0; @@ -536,6 +545,7 @@ class DynamicCapability::Server: public Capability::Server { private: InterfaceSchema schema; + Options options; }; template <> @@ -584,7 +594,6 @@ class CallContext: public kj::DisallowConstCopy { Orphanage getResultsOrphanage(kj::Maybe sizeHint = nullptr); template kj::Promise tailCall(Request&& tailRequest); - void allowCancellation(); StructSchema getParamsType() const { return paramType; } StructSchema getResultsType() const { return resultType; } @@ -1658,9 +1667,6 @@ inline kj::Promise CallContext::tailCall( Request&& tailRequest) { return hook->tailCall(kj::mv(tailRequest.hook)); } -inline void CallContext::allowCancellation() { - hook->allowCancellation(); -} template <> inline DynamicCapability::Client Capability::Client::castAs( @@ -1668,6 +1674,14 @@ inline DynamicCapability::Client Capability::Client::castAs( return DynamicCapability::Client(schema, hook->addRef()); } +template <> +inline DynamicCapability::Client CapabilityServerSet::add( + kj::Own&& server) { + void* ptr = reinterpret_cast(server.get()); + auto schema = server->getSchema(); + return addInternal(kj::mv(server), ptr).castAs(schema); +} + // ------------------------------------------------------------------- template diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/encoding-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/encoding-test.c++ index 6b71c5e1c22..84875015050 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/encoding-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/encoding-test.c++ @@ -1738,6 +1738,14 @@ TEST(Encoding, GlobalConstants) { EXPECT_EQ("structlist 2", listReader[1].getTextField()); EXPECT_EQ("structlist 3", listReader[2].getTextField()); } + + kj::StringPtr expected = + "foo bar baz\n" + "\"qux\" `corge` \'grault\'\n" + "regular\"quoted\"line" + "garply\\nwaldo\\tfred\\\"plugh\\\"xyzzy\\\'thud\n"; + + EXPECT_EQ(expected, test::BLOCK_TEXT); } TEST(Encoding, Embeds) { @@ -2019,6 +2027,81 @@ KJ_TEST("list.setWithCaveats(i, list[i]) doesn't corrupt contents") { checkTestMessage(list[1]); } +KJ_TEST("Downgrade pointer-list from struct-list") { + // Test that downgrading a list-of-structs to a list-of-pointers (where the relevant pointer is + // the struct's first pointer) works as advertised. + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + + { + auto list = root.getAnyPointerField().initAs>(2); + initTestMessage(list[0]); + list[1].setTextField("hello"); + } + + { + auto list = root.asReader().getAnyPointerField().getAs>(); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0] == "foo"); + KJ_EXPECT(list[1] == "hello"); + } +} + +KJ_TEST("Copying ListList downgraded from ListStruct does not get corrupted") { + // Test written by David Renshaw to demonstrate CVE-??? + + AlignedData<10> data = {{ + // struct, 1 pointer + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + // list, inline composite. 4 words. + 0x01, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, + + // one element, 3 data words, 1 pointer. + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x00, + + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // null struct pointer + + // bad bytes that shouldn't be visible from the root of the message + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + + // bug can cause this word to be read as the list element struct pointer + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + }}; + + kj::ArrayPtr segments[1] = { + // Only take the first 7 words. The last three words above should not be accessible + // from these segments. + kj::arrayPtr(data.words, 7) + }; + + SegmentArrayMessageReader reader(kj::arrayPtr(segments, 1)); + auto readerRoot = reader.getRoot(); + auto listList = readerRoot.getAnyPointerField().getAs>>(); + EXPECT_EQ(listList.size(), 1); + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + + root.getAnyPointerField().setAs>>(listList); + + auto outputSegments = builder.getSegmentsForOutput(); + ASSERT_EQ(outputSegments.size(), 1); + + auto inputBytes = segments[0].asBytes(); + auto outputBytes = outputSegments[0].asBytes(); + + ASSERT_EQ(outputBytes, inputBytes); + // Should be equal. Instead, we see that outputBytes includes the (copied) + // out-of-bounds 0xbb bytes from `data` above, which should be impossible. +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/ez-rpc.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/ez-rpc.c++ index ed402d4fb66..5871c77bf4c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/ez-rpc.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/ez-rpc.c++ @@ -177,10 +177,10 @@ Capability::Client EzRpcClient::importCap(kj::StringPtr name) { KJ_IF_MAYBE(client, impl->clientContext) { return client->get()->restore(name); } else { - return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name), - [this](kj::String&& name) { + return impl->setupPromise.addBranch().then( + [this,name=kj::heapString(name)]() { return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name); - })); + }); } } @@ -260,13 +260,11 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, portPromise = paf.promise.fork(); tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort) - .then(kj::mvCapture(paf.fulfiller, - [this, readerOpts](kj::Own>&& portFulfiller, - kj::Own&& addr) { + .then([this, portFulfiller=kj::mv(paf.fulfiller), readerOpts](kj::Own&& addr) mutable { auto listener = addr->listen(); portFulfiller->fulfill(listener->getPort()); acceptLoop(kj::mv(listener), readerOpts); - }))); + })); } Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, @@ -290,9 +288,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, void acceptLoop(kj::Own&& listener, ReaderOptions readerOpts) { auto ptr = listener.get(); - tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener), - [this, readerOpts](kj::Own&& listener, - kj::Own&& connection) { + tasks.add(ptr->accept().then([this, listener=kj::mv(listener), readerOpts](kj::Own&& connection) mutable { acceptLoop(kj::mv(listener), readerOpts); auto server = kj::heap(kj::mv(connection), *this, readerOpts); @@ -300,7 +296,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, // Arrange to destroy the server context when all references are gone, or when the // EzRpcServer is destroyed (which will destroy the TaskSet). tasks.add(server->network.onDisconnect().attach(kj::mv(server))); - }))); + })); } Capability::Client restore(AnyPointer::Reader objectId) override { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/generated-header-support.h b/libs/EXTERNAL/capnproto/c++/src/capnp/generated-header-support.h index 21f73126e7a..3c5b65665d5 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/generated-header-support.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/generated-header-support.h @@ -137,7 +137,7 @@ struct BrandBindingFor_, Kind::LIST> { template struct BrandBindingFor_ { static constexpr RawBrandedSchema::Binding get(uint16_t listDepth) { - return { 15, listDepth, nullptr }; + return { 15, listDepth, &rawSchema().defaultBrand }; } }; @@ -209,7 +209,7 @@ template class ConstStruct { public: ConstStruct() = delete; - KJ_DISALLOW_COPY(ConstStruct); + KJ_DISALLOW_COPY_AND_MOVE(ConstStruct); inline explicit constexpr ConstStruct(const word* ptr): ptr(ptr) {} inline typename T::Reader get() const { @@ -228,7 +228,7 @@ template class ConstList { public: ConstList() = delete; - KJ_DISALLOW_COPY(ConstList); + KJ_DISALLOW_COPY_AND_MOVE(ConstList); inline explicit constexpr ConstList(const word* ptr): ptr(ptr) {} inline typename List::Reader get() const { @@ -247,7 +247,7 @@ template class ConstText { public: ConstText() = delete; - KJ_DISALLOW_COPY(ConstText); + KJ_DISALLOW_COPY_AND_MOVE(ConstText); inline explicit constexpr ConstText(const word* ptr): ptr(ptr) {} inline Text::Reader get() const { @@ -275,7 +275,7 @@ template class ConstData { public: ConstData() = delete; - KJ_DISALLOW_COPY(ConstData); + KJ_DISALLOW_COPY_AND_MOVE(ConstData); inline explicit constexpr ConstData(const word* ptr): ptr(ptr) {} inline Data::Reader get() const { @@ -334,6 +334,13 @@ inline constexpr uint sizeInWords() { #define CAPNP_AUTO_IF_MSVC(...) __VA_ARGS__ #endif +// TODO(msvc): MSVC does not even expect constexprs to have definitions below C++17. +#if (KJ_CPP_STD < 201703L) && !(defined(_MSC_VER) && !defined(__clang__)) +#define CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL 1 +#else +#define CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL 0 +#endif + #if CAPNP_LITE #define CAPNP_DECLARE_SCHEMA(id) \ @@ -349,12 +356,11 @@ inline constexpr uint sizeInWords() { static inline ::capnp::word const* encodedSchema() { return bp_##id; } \ } -#if _MSC_VER && !defined(__clang__) -// TODO(msvc): MSVC doesn't expect constexprs to have definitions. -#define CAPNP_DEFINE_ENUM(type, id) -#else +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #define CAPNP_DEFINE_ENUM(type, id) \ constexpr uint64_t EnumInfo::typeId +#else +#define CAPNP_DEFINE_ENUM(type, id) #endif #define CAPNP_DECLARE_STRUCT_HEADER(id, dataWordSize_, pointerCount_) \ @@ -380,9 +386,14 @@ inline constexpr uint sizeInWords() { static inline ::capnp::word const* encodedSchema() { return bp_##id; } \ static constexpr ::capnp::_::RawSchema const* schema = &s_##id; \ } + +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #define CAPNP_DEFINE_ENUM(type, id) \ constexpr uint64_t EnumInfo::typeId; \ constexpr ::capnp::_::RawSchema const* EnumInfo::schema +#else +#define CAPNP_DEFINE_ENUM(type, id) +#endif #define CAPNP_DECLARE_STRUCT_HEADER(id, dataWordSize_, pointerCount_) \ struct IsStruct; \ diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/layout.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/layout.c++ index 7fa5b4e85f2..e2fa84df64c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/layout.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/layout.c++ @@ -321,6 +321,13 @@ static_assert(unboundAs(POINTERS * BITS_PER_POINTER / BITS_PER_BYTE / BY sizeof(WirePointer), "BITS_PER_POINTER is wrong."); +#define OUT_OF_BOUNDS_ERROR_DETAIL \ + "This usually indicates that " \ + "the input data was corrupted, used a different encoding than specified (e.g. " \ + "packed vs. non-packed), or was not a Cap'n Proto message to begin with. Note " \ + "that this error is NOT due to a schema mismatch; the input is invalid " \ + "regardless of schema." + namespace { static const union { @@ -484,6 +491,7 @@ struct WireHelpers { return reinterpret_cast(ref); } + KJ_ASSUME(segment != nullptr); word* ptr = segment->allocate(amount); if (ptr == nullptr) { @@ -577,7 +585,8 @@ struct WireHelpers { const word* ptr = ref->farTarget(segment); auto padWords = (ONE + bounded(ref->isDoubleFar())) * POINTER_SIZE_IN_WORDS; KJ_REQUIRE(boundsCheck(segment, ptr, padWords), - "Message contains out-of-bounds far pointer.") { + "Message contains out-of-bounds far pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return nullptr; } @@ -789,7 +798,8 @@ struct WireHelpers { switch (ref->kind()) { case WirePointer::STRUCT: { KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } result.addWords(ref->structRef.wordSize()); @@ -815,7 +825,8 @@ struct WireHelpers { upgradeBound(ref->listRef.elementCount()) * dataBitsPerElement(ref->listRef.elementSize())); KJ_REQUIRE(boundsCheck(segment, ptr, totalWords), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } result.addWords(totalWords); @@ -825,7 +836,8 @@ struct WireHelpers { auto count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); KJ_REQUIRE(boundsCheck(segment, ptr, count * WORDS_PER_POINTER), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } @@ -840,7 +852,8 @@ struct WireHelpers { case ElementSize::INLINE_COMPOSITE: { auto wordCount = ref->listRef.inlineCompositeWordCount(); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } @@ -855,7 +868,8 @@ struct WireHelpers { auto actualSize = elementTag->structRef.wordSize() / ELEMENTS * upgradeBound(count); KJ_REQUIRE(actualSize <= wordCount, - "Struct list pointer's elements overran size.") { + "Struct list pointer's elements overran size. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } @@ -1130,7 +1144,7 @@ struct WireHelpers { word* oldPtr = followFars(oldRef, refTarget, oldSegment); KJ_REQUIRE(oldRef->kind() == WirePointer::STRUCT, - "Message contains non-struct pointer where struct pointer was expected.") { + "Schema mismatch: Message contains non-struct pointer where struct pointer was expected.") { goto useDefault; } @@ -1272,7 +1286,7 @@ struct WireHelpers { word* ptr = followFars(ref, origRefTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getWritableListPointer() but existing pointer is not a list.") { + "Schema mismatch: Called getWritableListPointer() but existing pointer is not a list.") { goto useDefault; } @@ -1300,8 +1314,8 @@ struct WireHelpers { case ElementSize::BIT: KJ_FAIL_REQUIRE( - "Found struct list where bit list was expected; upgrading boolean lists to structs " - "is no longer supported.") { + "Schema mismatch: Found struct list where bit list was expected; upgrading boolean " + "lists to structs is no longer supported.") { goto useDefault; } break; @@ -1311,14 +1325,14 @@ struct WireHelpers { case ElementSize::FOUR_BYTES: case ElementSize::EIGHT_BYTES: KJ_REQUIRE(dataSize >= ONE * WORDS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } break; case ElementSize::POINTER: KJ_REQUIRE(pointerCount >= ONE * POINTERS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } // Adjust the pointer to point at the reference segment. @@ -1341,20 +1355,20 @@ struct WireHelpers { if (elementSize == ElementSize::BIT) { KJ_REQUIRE(oldSize == ElementSize::BIT, - "Found non-bit list where bit list was expected.") { + "Schema mismatch: Found non-bit list where bit list was expected.") { goto useDefault; } } else { KJ_REQUIRE(oldSize != ElementSize::BIT, - "Found bit list where non-bit list was expected.") { + "Schema mismatch: Found bit list where non-bit list was expected.") { goto useDefault; } KJ_REQUIRE(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } KJ_REQUIRE(pointerCount >= pointersPerElement(elementSize) * ELEMENTS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } } @@ -1392,7 +1406,8 @@ struct WireHelpers { word* ptr = followFars(ref, origRefTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getWritableListPointerAnySize() but existing pointer is not a list.") { + "Schema mismatch: Called getWritableListPointerAnySize() but existing pointer is not a " + "list.") { goto useDefault; } @@ -1448,7 +1463,8 @@ struct WireHelpers { word* oldPtr = followFars(oldRef, origRefTarget, oldSegment); KJ_REQUIRE(oldRef->kind() == WirePointer::LIST, - "Called getList{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getList{Field,Element}() but existing pointer is not a " + "list.") { goto useDefault; } @@ -1543,8 +1559,8 @@ struct WireHelpers { // Upgrading to an inline composite list. KJ_REQUIRE(oldSize != ElementSize::BIT, - "Found bit list where struct list was expected; upgrading boolean lists to structs " - "is no longer supported.") { + "Schema mismatch: Found bit list where struct list was expected; upgrading boolean " + "lists to structs is no longer supported.") { goto useDefault; } @@ -1662,11 +1678,12 @@ struct WireHelpers { byte* bptr = reinterpret_cast(ptr); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getText{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getText{Field,Element}() but existing pointer is not a list.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Called getText{Field,Element}() but existing list pointer is not byte-sized.") { + "Schema mismatch: Called getText{Field,Element}() but existing list pointer is not " + "byte-sized.") { goto useDefault; } @@ -1733,11 +1750,12 @@ struct WireHelpers { word* ptr = followFars(ref, refTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getData{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getData{Field,Element}() but existing pointer is not a list.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Called getData{Field,Element}() but existing list pointer is not byte-sized.") { + "Schema mismatch: Called getData{Field,Element}() but existing list pointer is not " + "byte-sized.") { goto useDefault; } @@ -1964,7 +1982,8 @@ struct WireHelpers { } KJ_REQUIRE(boundsCheck(srcSegment, ptr, src->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } return setStructPointer(dstSegment, dstCapTable, dst, @@ -1988,7 +2007,8 @@ struct WireHelpers { const WirePointer* tag = reinterpret_cast(ptr); KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2031,7 +2051,8 @@ struct WireHelpers { auto wordCount = roundBitsUpToWords(upgradeBound(elementCount) * step); KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2175,12 +2196,14 @@ struct WireHelpers { } KJ_REQUIRE(ref->kind() == WirePointer::STRUCT, - "Message contains non-struct pointer where struct pointer was expected.") { + "Schema mismatch: Message contains non-struct pointer where struct pointer" + "was expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2209,7 +2232,8 @@ struct WireHelpers { return brokenCapFactory->newNullCap(); } else if (!ref->isCapability()) { KJ_FAIL_REQUIRE( - "Message contains non-capability pointer where capability pointer was expected.") { + "Schema mismatch: Message contains non-capability pointer where capability pointer was " + "expected.") { break; } return brokenCapFactory->newBrokenCap( @@ -2263,7 +2287,8 @@ struct WireHelpers { } KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where list pointer was expected.") { + "Schema mismatch: Message contains non-list pointer where list pointer was " + "expected.") { goto useDefault; } @@ -2275,7 +2300,8 @@ struct WireHelpers { const WirePointer* tag = reinterpret_cast(ptr); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2327,18 +2353,16 @@ struct WireHelpers { case ElementSize::FOUR_BYTES: case ElementSize::EIGHT_BYTES: KJ_REQUIRE(tag->structRef.dataSize.get() > ZERO * WORDS, - "Expected a primitive list, but got a list of pointer-only structs.") { + "Schema mismatch: Expected a primitive list, but got a list of pointer-only " + "structs.") { goto useDefault; } break; case ElementSize::POINTER: - // We expected a list of pointers but got a list of structs. Assuming the first field - // in the struct is the pointer we were looking for, we want to munge the pointer to - // point at the first element's pointer section. - ptr += tag->structRef.dataSize.get(); KJ_REQUIRE(tag->structRef.ptrCount.get() > ZERO * POINTERS, - "Expected a pointer list, but got a list of data-only structs.") { + "Schema mismatch: Expected a pointer list, but got a list of data-only " + "structs.") { goto useDefault; } break; @@ -2364,7 +2388,8 @@ struct WireHelpers { auto wordCount = roundBitsUpToWords(upgradeBound(elementCount) * step); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2397,11 +2422,11 @@ struct WireHelpers { pointersPerElement(expectedElementSize) * ELEMENTS; KJ_REQUIRE(expectedDataBitsPerElement <= dataSize, - "Message contained list with incompatible element type.") { + "Schema mismatch: Message contained list with incompatible element type.") { goto useDefault; } KJ_REQUIRE(expectedPointersPerElement <= pointerCount, - "Message contained list with incompatible element type.") { + "Schema mismatch: Message contained list with incompatible element type.") { goto useDefault; } } @@ -2436,17 +2461,19 @@ struct WireHelpers { auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where text was expected.") { + "Schema mismatch: Message contains non-list pointer where text was expected.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Message contains list pointer of non-bytes where text was expected.") { + "Schema mismatch: Message contains list pointer of non-bytes where text was " + "expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)), - "Message contained out-of-bounds text pointer.") { + "Message contained out-of-bounds text pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2494,17 +2521,19 @@ struct WireHelpers { auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where data was expected.") { + "Schema mismatch: Message contains non-list pointer where data was expected.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Message contains list pointer of non-bytes where data was expected.") { + "Schema mismatch: Message contains list pointer of non-bytes where data was " + "expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)), - "Message contained out-of-bounds data pointer.") { + "Message contained out-of-bounds data pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -3092,7 +3121,7 @@ ListBuilder ListBuilder::imbue(CapTableBuilder* capTable) { Text::Reader ListReader::asText() { KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS, - "Expected Text, got list of non-bytes.") { + "Schema mismatch: Expected Text, got list of non-bytes.") { return Text::Reader(); } @@ -3114,7 +3143,7 @@ Text::Reader ListReader::asText() { Data::Reader ListReader::asData() { KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS, - "Expected Text, got list of non-bytes.") { + "Schema mismatch: Expected Text, got list of non-bytes.") { return Data::Reader(); } @@ -3123,7 +3152,7 @@ Data::Reader ListReader::asData() { kj::ArrayPtr ListReader::asRawBytes() const { KJ_REQUIRE(structPointerCount == ZERO * POINTERS, - "Expected data only, got pointers.") { + "Schema mismatch: Expected data only, got pointers.") { return kj::ArrayPtr(); } @@ -3661,7 +3690,7 @@ bool OrphanBuilder::truncate(ElementCount uncheckedSize, bool isText) { return size == ZERO * ELEMENTS; } - KJ_REQUIRE(ref->kind() == WirePointer::LIST, "Can't truncate non-list.") { + KJ_REQUIRE(ref->kind() == WirePointer::LIST, "Schema mismatch: Can't truncate non-list.") { return false; } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/layout.h b/libs/EXTERNAL/capnproto/c++/src/capnp/layout.h index c8d533cff1b..7a27f68a1f8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/layout.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/layout.h @@ -1227,8 +1227,12 @@ inline Void ListReader::getDataElement(ElementCount index) const { } inline PointerReader ListReader::getPointerElement(ElementCount index) const { + // If the list elements have data sections we need to skip those. Note that for pointers to be + // present at all (which already must be true if we get here), then `structDataSize` must be a + // whole number of words, so we don't have to worry about unaligned reads here. + auto offset = structDataSize / BITS_PER_BYTE; return PointerReader(segment, capTable, reinterpret_cast( - ptr + upgradeBound(index) * step / BITS_PER_BYTE), nestingLimit); + ptr + offset + upgradeBound(index) * step / BITS_PER_BYTE), nestingLimit); } // ------------------------------------------------------------------- diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane-test.c++ index 4fa928e45eb..9f3bc215836 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane-test.c++ @@ -90,7 +90,6 @@ protected: } kj::Promise waitForever(WaitForeverContext context) override { - context.allowCancellation(); return kj::NEVER_DONE; } }; @@ -129,6 +128,8 @@ public: }); } + bool shouldResolveBeforeRedirecting() override { return true; } + private: kj::Maybe> revokePromise; }; @@ -276,6 +277,33 @@ KJ_TEST("apply membrane using copyOutOfMembrane() on AnyPointer") { }, "inside", "inbound", "inside", "inside"); } +KJ_TEST("MembraneHook::whenMoreResolved returns same value even when called concurrently.") { + TestEnv env; + + auto paf = kj::newPromiseAndFulfiller(); + test::TestMembrane::Client promCap(kj::mv(paf.promise)); + + auto prom = promCap.whenResolved(); + prom = prom.then([promCap = kj::mv(promCap), &env]() mutable { + auto membraned = membrane(kj::mv(promCap), env.policy->addRef()); + auto hook = ClientHook::from(membraned); + + auto arr = kj::heapArrayBuilder>>(2); + arr.add(KJ_ASSERT_NONNULL(hook->whenMoreResolved())); + arr.add(KJ_ASSERT_NONNULL(hook->whenMoreResolved())); + + return kj::joinPromises(arr.finish()).attach(kj::mv(hook)); + }).then([](kj::Vector> hooks) { + auto first = hooks[0].get(); + auto second = hooks[1].get(); + KJ_ASSERT(first == second); + }).eagerlyEvaluate(nullptr); + + auto newClient = kj::heap(); + paf.fulfiller->fulfill(kj::mv(newClient)); + prom.wait(env.waitScope); +} + struct TestRpcEnv { kj::EventLoop loop; kj::WaitScope waitScope; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.c++ index 732f062fd9b..845ff954894 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.c++ @@ -201,14 +201,14 @@ public: auto onRevoked = policy->onRevoked(); bool reverse = this->reverse; // for capture - auto newPromise = promise.then(kj::mvCapture(policy, - [reverse](kj::Own&& policy, Response&& response) { + auto newPromise = promise.then( + [reverse,policy=kj::mv(policy)](Response&& response) mutable { AnyPointer::Reader reader = response; auto newRespHook = kj::heap( ResponseHook::from(kj::mv(response)), policy->addRef(), reverse); reader = newRespHook->imbue(reader); return Response(reader, kj::mv(newRespHook)); - })); + }); KJ_IF_MAYBE(r, kj::mv(onRevoked)) { newPromise = newPromise.exclusiveJoin(r->then([]() -> Response { @@ -231,6 +231,11 @@ public: return promise; } + AnyPointer::Pipeline sendForPipeline() override { + return AnyPointer::Pipeline(kj::refcounted( + PipelineHook::from(inner->sendForPipeline()), policy->addRef(), reverse)); + } + const void* getBrand() override { return MEMBRANE_BRAND; } @@ -286,10 +291,6 @@ public: return inner->tailCall(MembraneRequestHook::wrap(kj::mv(request), *policy, !reverse)); } - void allowCancellation() override { - inner->allowCancellation(); - } - kj::Promise onTailCall() override { return inner->onTailCall().then([this](AnyPointer::Pipeline&& innerPipeline) { return AnyPointer::Pipeline(kj::refcounted( @@ -324,6 +325,8 @@ private: kj::Maybe results; }; +} // namespace + class MembraneHook final: public ClientHook, public kj::Refcounted { public: MembraneHook(kj::Own&& inner, kj::Own&& policyParam, bool reverse) @@ -335,6 +338,11 @@ public: } } + ~MembraneHook() noexcept(false) { + auto& map = reverse ? policy->reverseWrappers : policy->wrappers; + map.erase(inner.get()); + } + static kj::Own wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) { if (cap.getBrand() == MEMBRANE_BRAND) { auto& otherMembrane = kj::downcast(cap); @@ -350,9 +358,19 @@ public: } } - return ClientHook::from( - reverse ? policy.importExternal(Capability::Client(cap.addRef())) - : policy.exportInternal(Capability::Client(cap.addRef()))); + auto& map = reverse ? policy.reverseWrappers : policy.wrappers; + ClientHook*& slot = map.findOrCreate(&cap, [&]() -> kj::Decay::Entry { + return { &cap, nullptr }; + }); + if (slot == nullptr) { + auto result = ClientHook::from( + reverse ? policy.importExternal(Capability::Client(cap.addRef())) + : policy.exportInternal(Capability::Client(cap.addRef()))); + slot = result; + return result; + } else { + return slot->addRef(); + } } static kj::Own wrap(kj::Own cap, MembranePolicy& policy, bool reverse) { @@ -370,15 +388,26 @@ public: } } - return ClientHook::from( - reverse ? policy.importExternal(Capability::Client(kj::mv(cap))) - : policy.exportInternal(Capability::Client(kj::mv(cap)))); + auto& map = reverse ? policy.reverseWrappers : policy.wrappers; + ClientHook*& slot = map.findOrCreate(cap.get(), [&]() -> kj::Decay::Entry { + return { cap.get(), nullptr }; + }); + if (slot == nullptr) { + auto result = ClientHook::from( + reverse ? policy.importExternal(Capability::Client(kj::mv(cap))) + : policy.exportInternal(Capability::Client(kj::mv(cap)))); + slot = result; + return result; + } else { + return slot->addRef(); + } } Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { KJ_IF_MAYBE(r, resolved) { - return r->get()->newCall(interfaceId, methodId, sizeHint); + return r->get()->newCall(interfaceId, methodId, sizeHint, hints); } auto redirect = reverse @@ -392,23 +421,24 @@ public: // otherwise behavior will differ depending on whether the promise is resolved. KJ_IF_MAYBE(p, whenMoreResolved()) { return newLocalPromiseClient(p->attach(addRef())) - ->newCall(interfaceId, methodId, sizeHint); + ->newCall(interfaceId, methodId, sizeHint, hints); } } - return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint); + return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint, hints); } else { // For pass-through calls, we don't worry about promises, because if the capability resolves // to something outside the membrane, then the call will pass back out of the membrane too. return MembraneRequestHook::wrap( - inner->newCall(interfaceId, methodId, sizeHint), *policy, reverse); + inner->newCall(interfaceId, methodId, sizeHint, hints), *policy, reverse); } } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, + CallHints hints) override { KJ_IF_MAYBE(r, resolved) { - return r->get()->call(interfaceId, methodId, kj::mv(context)); + return r->get()->call(interfaceId, methodId, kj::mv(context), hints); } auto redirect = reverse @@ -422,17 +452,21 @@ public: // otherwise behavior will differ depending on whether the promise is resolved. KJ_IF_MAYBE(p, whenMoreResolved()) { return newLocalPromiseClient(p->attach(addRef())) - ->call(interfaceId, methodId, kj::mv(context)); + ->call(interfaceId, methodId, kj::mv(context), hints); } } - return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context)); + return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context), hints); } else { // !reverse because calls to the CallContext go in the opposite direction. auto result = inner->call(interfaceId, methodId, - kj::refcounted(kj::mv(context), policy->addRef(), !reverse)); + kj::refcounted(kj::mv(context), policy->addRef(), !reverse), + hints); - KJ_IF_MAYBE(r, policy->onRevoked()) { + if (hints.onlyPromisePipeline) { + // Just in case the called capability returned a valid promise, replace it here. + result.promise = kj::NEVER_DONE; + } else KJ_IF_MAYBE(r, policy->onRevoked()) { result.promise = result.promise.exclusiveJoin(kj::mv(*r)); } @@ -471,11 +505,14 @@ public: } return promise->then([this](kj::Own&& newInner) { - kj::Own newResolved = wrap(*newInner, *policy, reverse); - if (resolved == nullptr) { - resolved = newResolved->addRef(); + // There's a chance resolved was set by getResolved() or a concurrent whenMoreResolved() + // while we yielded the event loop. If the inner ClientHook is maintaining the contract, + // then resolved would already be set to newInner after wrapping in a MembraneHook. + KJ_IF_MAYBE(r, resolved) { + return (*r)->addRef(); + } else { + return resolved.emplace(wrap(*newInner, *policy, reverse))->addRef(); } - return newResolved; }); } else { return nullptr; @@ -491,9 +528,11 @@ public: } kj::Maybe getFd() override { - // We can't let FDs pass over membranes because we have no way to enforce the membrane policy - // on them. If the MembranePolicy wishes to explicitly permit certain FDs to pass, it can - // always do so by overriding the appropriate policy methods. + KJ_IF_MAYBE(f, inner->getFd()) { + if (policy->allowFdPassthrough()) { + return *f; + } + } return nullptr; } @@ -505,6 +544,8 @@ private: kj::Promise revocationTask = nullptr; }; +namespace { + kj::Own membrane(kj::Own inner, MembranePolicy& policy, bool reverse) { return MembraneHook::wrap(kj::mv(inner), policy, reverse); } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.h b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.h index d51e2308585..60629cb4ddf 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/membrane.h @@ -48,6 +48,9 @@ // Mark Miller on membranes: http://www.eros-os.org/pipermail/e-lang/2003-January/008434.html #include "capability.h" +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -114,7 +117,7 @@ class MembranePolicy { // invoked for new calls, but the `target` passed to them will be a capability that always // rethrows the revocation exception. - virtual bool shouldResolveBeforeRedirecting() { return true; } + virtual bool shouldResolveBeforeRedirecting() { return false; } // If this returns true, then when inboundCall() or outboundCall() returns a redirect, but the // original target is a promise, then the membrane will discard the redirect and instead wait // for the promise to become more resolved and try again. @@ -126,12 +129,20 @@ class MembranePolicy { // capability without applying the policy at all. // // However, some membranes don't need this behavior, and may be negatively impacted by the - // unnecessary waiting. Such membranes should override this to return false. + // unnecessary waiting. Such membranes can keep this disabled. // // TODO(cleanup): Consider a backwards-incompatible revamp of the MembranePolicy API with a // better design here. Maybe we should more carefully distinguish between MembranePolicies // which are reversible vs. those which are one-way? + virtual bool allowFdPassthrough() { return false; } + // Should file descriptors be allowed to pass through this membrane? + // + // A MembranePolicy obviously cannot mediate nor revoke access to a file descriptor once it has + // passed through, so this must be used with caution. If you only want to allow file descriptors + // on certain methods, you could do so by implementing inboundCall()/outboundCall() to + // special-case those methods. + // --------------------------------------------------------------------------- // Control over importing and exporting. // @@ -181,6 +192,15 @@ class MembranePolicy { // capability passed into the membrane and then back out. // // The default implementation simply returns `external`. + +private: + kj::HashMap wrappers; + kj::HashMap reverseWrappers; + // Tracks capabilities that already have wrappers instantiated. The maps map from pointer to + // inner capability to pointer to wrapper. When a wrapper is destroyed it removes itself from + // the map. + + friend class MembraneHook; }; Capability::Client membrane(Capability::Client inner, kj::Own policy); @@ -275,3 +295,5 @@ Orphan::Reads> copyOutOfMembrane( } } // namespace capnp + +CAPNP_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/message-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/message-test.c++ index 4d7362e3ff6..6545b17a6fb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/message-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/message-test.c++ @@ -192,6 +192,27 @@ KJ_TEST("disallow unaligned") { } #endif +KJ_TEST("MessageBuilder::sizeInWords()") { + capnp::MallocMessageBuilder builder; + auto root = builder.initRoot(); + initTestMessage(root); + + size_t expected = root.totalSize().wordCount + 1; + + KJ_EXPECT(builder.sizeInWords() == expected); + + auto segments = builder.getSegmentsForOutput(); + size_t total = 0; + for (auto& segment: segments) { + total += segment.size(); + } + KJ_EXPECT(total == expected); + + capnp::SegmentArrayMessageReader reader(segments); + checkTestMessage(reader.getRoot()); + KJ_EXPECT(reader.sizeInWords() == expected); +} + // TODO(test): More tests. } // namespace diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/message.h b/libs/EXTERNAL/capnproto/c++/src/capnp/message.h index 55a8b2e98d2..af87aec819f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/message.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/message.h @@ -127,7 +127,7 @@ class MessageReader { private: ReaderOptions options; -#if defined(__EMSCRIPTEN__) +#if defined(__EMSCRIPTEN__) || (defined(__APPLE__) && defined(__ppc__)) static constexpr size_t arenaSpacePadding = 19; #else static constexpr size_t arenaSpacePadding = 18; @@ -159,7 +159,7 @@ class MessageBuilder { public: MessageBuilder(); virtual ~MessageBuilder() noexcept(false); - KJ_DISALLOW_COPY(MessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(MessageBuilder); struct SegmentInit { kj::ArrayPtr space; @@ -343,7 +343,7 @@ class SegmentArrayMessageReader: public MessageReader { // Creates a message pointing at the given segment array, without taking ownership of the // segments. All arrays passed in must remain valid until the MessageReader is destroyed. - KJ_DISALLOW_COPY(SegmentArrayMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(SegmentArrayMessageReader); ~SegmentArrayMessageReader() noexcept(false); virtual kj::ArrayPtr getSegment(uint id) override; @@ -400,7 +400,7 @@ class MallocMessageBuilder: public MessageBuilder { // firstSegment MUST be zero-initialized. MallocMessageBuilder's destructor will write new zeros // over any space that was used so that it can be reused. - KJ_DISALLOW_COPY(MallocMessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(MallocMessageBuilder); virtual ~MallocMessageBuilder() noexcept(false); virtual kj::ArrayPtr allocateSegment(uint minimumSize) override; @@ -431,7 +431,7 @@ class FlatMessageBuilder: public MessageBuilder { public: explicit FlatMessageBuilder(kj::ArrayPtr array); - KJ_DISALLOW_COPY(FlatMessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(FlatMessageBuilder); virtual ~FlatMessageBuilder() noexcept(false); void requireFilled(); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/orphan.h b/libs/EXTERNAL/capnproto/c++/src/capnp/orphan.h index ab226500db8..0ef4a671c8d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/orphan.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/orphan.h @@ -71,7 +71,7 @@ class Orphan { // If the new size is less than the original, the remaining elements will be discarded. The // list is never moved in this case. If the list happens to be located at the end of its segment // (which is always true if the list was the last thing allocated), the removed memory will be - // reclaimed (reducing the messag size), otherwise it is simply zeroed. The reclaiming behavior + // reclaimed (reducing the message size), otherwise it is simply zeroed. The reclaiming behavior // is particularly useful for allocating buffer space when you aren't sure how much space you // actually need: you can pre-allocate, say, a 4k byte array, read() from a file into it, and // then truncate it back to the amount of space actually used. diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.c++ index 17ee6f45101..195c71549b3 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.c++ @@ -74,7 +74,7 @@ KJ_CONSTEXPR(const) ::capnp::_::RawBrandedSchema::Dependency bd_c8cb212fcd9f5691 }; const ::capnp::_::RawSchema s_c8cb212fcd9f5691 = { 0xc8cb212fcd9f5691, b_c8cb212fcd9f5691.words, 54, d_c8cb212fcd9f5691, m_c8cb212fcd9f5691, - 2, 1, nullptr, nullptr, nullptr, { &s_c8cb212fcd9f5691, nullptr, bd_c8cb212fcd9f5691, 0, sizeof(bd_c8cb212fcd9f5691) / sizeof(bd_c8cb212fcd9f5691[0]), nullptr } + 2, 1, nullptr, nullptr, nullptr, { &s_c8cb212fcd9f5691, nullptr, bd_c8cb212fcd9f5691, 0, sizeof(bd_c8cb212fcd9f5691) / sizeof(bd_c8cb212fcd9f5691[0]), nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<35> b_f76fba59183073a5 = { @@ -120,7 +120,7 @@ static const uint16_t m_f76fba59183073a5[] = {0}; static const uint16_t i_f76fba59183073a5[] = {0}; const ::capnp::_::RawSchema s_f76fba59183073a5 = { 0xf76fba59183073a5, b_f76fba59183073a5.words, 35, nullptr, m_f76fba59183073a5, - 0, 1, i_f76fba59183073a5, nullptr, nullptr, { &s_f76fba59183073a5, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_f76fba59183073a5, nullptr, nullptr, { &s_f76fba59183073a5, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<36> b_b76848c18c40efbf = { @@ -167,7 +167,7 @@ static const uint16_t m_b76848c18c40efbf[] = {0}; static const uint16_t i_b76848c18c40efbf[] = {0}; const ::capnp::_::RawSchema s_b76848c18c40efbf = { 0xb76848c18c40efbf, b_b76848c18c40efbf.words, 36, nullptr, m_b76848c18c40efbf, - 0, 1, i_b76848c18c40efbf, nullptr, nullptr, { &s_b76848c18c40efbf, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_b76848c18c40efbf, nullptr, nullptr, { &s_b76848c18c40efbf, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<22> b_f622595091cafb67 = { @@ -198,7 +198,7 @@ static const ::capnp::_::AlignedData<22> b_f622595091cafb67 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_f622595091cafb67 = { 0xf622595091cafb67, b_f622595091cafb67.words, 22, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_f622595091cafb67, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_f622595091cafb67, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.h index 60ea65b24fe..ee0abb37f1a 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/persistent.capnp.h @@ -9,7 +9,9 @@ #include #endif // !CAPNP_LITE -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif @@ -430,15 +432,19 @@ inline ::capnp::Orphan Persistent::SaveParams::Builder: } // Persistent::SaveParams +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr uint16_t Persistent::SaveParams::_capnpPrivate::dataWordSize; template constexpr uint16_t Persistent::SaveParams::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::SaveParams::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::SaveParams::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::SaveParams::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, @@ -509,15 +515,19 @@ inline ::capnp::Orphan Persistent::SaveResults::Bui } // Persistent::SaveResults +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr uint16_t Persistent::SaveResults::_capnpPrivate::dataWordSize; template constexpr uint16_t Persistent::SaveResults::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::SaveResults::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::SaveResults::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::SaveResults::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, @@ -539,7 +549,7 @@ template CAPNP_AUTO_IF_MSVC(::capnp::Request::SaveParams, typename ::capnp::Persistent::SaveResults>) Persistent::Client::saveRequest(::kj::Maybe< ::capnp::MessageSize> sizeHint) { return newCall::SaveParams, typename ::capnp::Persistent::SaveResults>( - 0xc8cb212fcd9f5691ull, 0, sizeHint); + 0xc8cb212fcd9f5691ull, 0, sizeHint, {false}); } template ::kj::Promise Persistent::Server::save(SaveContext) { @@ -567,6 +577,7 @@ ::capnp::Capability::Server::DispatchCallResult Persistent::Se return { save(::capnp::Capability::Server::internalGetTypedContext< typename ::capnp::Persistent::SaveParams, typename ::capnp::Persistent::SaveResults>(context)), + false, false }; default: @@ -580,10 +591,12 @@ ::capnp::Capability::Server::DispatchCallResult Persistent::Se // Persistent #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/raw-schema.h b/libs/EXTERNAL/capnproto/c++/src/capnp/raw-schema.h index 88101692a20..44b696c5ca1 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/raw-schema.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/raw-schema.h @@ -226,6 +226,9 @@ struct RawSchema { // Specifies the brand to use for this schema if no generic parameters have been bound to // anything. Generally, in the default brand, all generic parameters are treated as if they were // bound to `AnyPointer`. + + bool mayContainCapabilities = true; + // See StructSchema::mayContainCapabilities. }; inline bool RawBrandedSchema::isUnbound() const { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect-test.c++ index 492f5489be3..12ef59333d1 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect-test.c++ @@ -85,7 +85,8 @@ void doAutoReconnectTest(kj::WaitScope& ws, KJ_EXPECT(test(123, true) == "123 true 0"); currentServer->setError(KJ_EXCEPTION(DISCONNECTED, "test1 disconnect")); - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", test(456, true)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", + testPromise(456, true).ignoreResult().wait(ws)); KJ_EXPECT(test(789, false) == "789 false 1"); KJ_EXPECT(test(21, true) == "21 true 1"); @@ -99,8 +100,8 @@ void doAutoReconnectTest(kj::WaitScope& ws, KJ_EXPECT(!promise1.poll(ws)); KJ_EXPECT(!promise2.poll(ws)); fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "test2 disconnect")); - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise1.wait(ws)); - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise2.wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise1.ignoreResult().wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise2.ignoreResult().wait(ws)); } KJ_EXPECT(test(43, false) == "43 false 2"); @@ -127,7 +128,7 @@ void doAutoReconnectTest(kj::WaitScope& ws, client = nullptr; // Everything we initiated should still finish. - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test3 disconnect", promise4.wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test3 disconnect", promise4.ignoreResult().wait(ws)); // Send the request which we created before the disconnect. There are two behaviors we accept // as correct here: it may throw the disconnect exception, or it may automatically redirect to @@ -195,6 +196,12 @@ KJ_TEST("lazyAutoReconnect() initialies lazily") { req.setJ(j); return kj::str(req.send().wait(ws).getX()); }; + auto testIgnoreResult = [&](uint i, bool j) { + auto req = client.fooRequest(); + req.setI(i); + req.setJ(j); + req.send().ignoreResult().wait(ws); + }; KJ_EXPECT(connectCount == 1); KJ_EXPECT(test(123, true) == "123 true 0"); @@ -208,7 +215,7 @@ KJ_TEST("lazyAutoReconnect() initialies lazily") { KJ_EXPECT(connectCount == 2); currentServer->setError(KJ_EXCEPTION(DISCONNECTED, "test1 disconnect")); - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", test(345, true)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", testIgnoreResult(345, true)); // lazyAutoReconnect is only lazy on the first request, not on reconnects. KJ_EXPECT(connectCount == 3); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.c++ index fe2bd07f8dc..2a8c67f6c23 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.c++ @@ -32,17 +32,26 @@ public: current(lazy ? kj::Maybe>() : ClientHook::from(connect())) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - auto result = getCurrent().newCall(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + auto result = getCurrent().newCall(interfaceId, methodId, sizeHint, hints); AnyPointer::Builder builder = result; auto hook = kj::heap(kj::addRef(*this), RequestHook::from(kj::mv(result))); return { builder, kj::mv(hook) }; } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - auto result = getCurrent().call(interfaceId, methodId, kj::mv(context)); - wrap(result.promise); + kj::Own&& context, CallHints hints) override { + auto result = getCurrent().call(interfaceId, methodId, kj::mv(context), hints); + if (hints.onlyPromisePipeline) { + // Just in case the callee didn't implement the hint, replace its promise. + result.promise = kj::NEVER_DONE; + + // TODO(bug): In this case we won't detect cancellation. This is essentially the same + // bug as described in `RequestImpl::send()` below, and will need the same solution. + } else { + wrap(result.promise); + } return result; } @@ -109,6 +118,12 @@ private: RemotePromise send() override { auto result = inner->send(); + // TODO(bug): If the returned promise is dropped, e.g. because the caller only cares about + // pipelining, then the DISCONNECTED exception will not be noticed. I suppose we have to + // split the promise and hold one branch, but we don't want to prevent cancellation, so + // we only want to hold that branch as long as the PipelineHook or some pipelined + // capability obtained through it lives. So we need a bunch of custom wrappers for that. + // Ugh. parent->wrap(result); return result; } @@ -119,6 +134,11 @@ private: return result; } + AnyPointer::Pipeline sendForPipeline() override { + // TODO(bug): This definitely fails to detect disconnects; see comment in send(). + return inner->sendForPipeline(); + } + const void* getBrand() override { return nullptr; } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.h b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.h index 6f7d3d62d0d..4e430951e92 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/reconnect.h @@ -21,7 +21,7 @@ #pragma once -#include "capability.h" +#include #include CAPNP_BEGIN_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-prelude.h b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-prelude.h index c6165d13233..742aa868c98 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-prelude.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-prelude.h @@ -24,7 +24,7 @@ #pragma once -#include "capability.h" +#include #include "persistent.capnp.h" CAPNP_BEGIN_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-test.c++ index 6211d7e3649..da0d7abcbc7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-test.c++ @@ -288,8 +288,8 @@ public: auto incomingMessage = kj::heap(messageToFlatArray(message)); auto connectionPtr = &connection; - connection.tasks->add(kj::evalLater(kj::mvCapture(incomingMessage, - [connectionPtr](kj::Own&& message) { + connection.tasks->add(kj::evalLater( + [connectionPtr,message=kj::mv(incomingMessage)]() mutable { KJ_IF_MAYBE(p, connectionPtr->partner) { if (p->fulfillers.empty()) { p->messages.push(kj::mv(message)); @@ -300,7 +300,7 @@ public: p->fulfillers.pop(); } } - }))); + })); } size_t sizeInWords() override { @@ -566,6 +566,42 @@ TEST(Rpc, Pipelining) { EXPECT_EQ(1, chainedCallCount); } +KJ_TEST("RPC sendForPipeline()") { + TestContext context; + + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE) + .castAs(); + + int chainedCallCount = 0; + + auto request = client.getCapRequest(); + request.setN(234); + request.setInCap(kj::heap(chainedCallCount)); + + auto pipeline = request.sendForPipeline(); + + auto pipelineRequest = pipeline.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + auto pipelineRequest2 = pipeline.getOutBox().getCap().castAs().graultRequest(); + auto pipelinePromise2 = pipelineRequest2.send(); + + pipeline = nullptr; // Just to be annoying, drop the original pipeline. + + EXPECT_EQ(0, context.restorer.callCount); + EXPECT_EQ(0, chainedCallCount); + + auto response = pipelinePromise.wait(context.waitScope); + EXPECT_EQ("bar", response.getX()); + + auto response2 = pipelinePromise2.wait(context.waitScope); + checkTestMessage(response2); + + EXPECT_EQ(3, context.restorer.callCount); + EXPECT_EQ(1, chainedCallCount); +} + KJ_TEST("RPC context.setPipeline") { TestContext context; @@ -703,7 +739,6 @@ public: : callCount(callCount), cancelCount(cancelCount) {} kj::Promise foo(FooContext context) override { - context.allowCancellation(); ++callCount; return kj::Promise(kj::NEVER_DONE) .attach(kj::defer([&cancelCount = cancelCount]() { ++cancelCount; })); @@ -798,8 +833,8 @@ TEST(Rpc, TailCallCancelRace) { KJ_ASSERT(cancelCount == 1); } -TEST(Rpc, Cancelation) { - // Tests allowCancellation(). +TEST(Rpc, Cancellation) { + // Tests cancellation. TestContext context; @@ -1305,6 +1340,23 @@ KJ_TEST("method throws exception") { KJ_EXPECT(exception.getRemoteTrace() == nullptr); } +KJ_TEST("method throws exception won't redundantly add remote exception prefix") { + TestContext context; + + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); + + kj::Maybe maybeException; + client.throwRemoteExceptionRequest().send().ignoreResult() + .catch_([&](kj::Exception&& e) { + maybeException = kj::mv(e); + }).wait(context.waitScope); + + auto exception = KJ_ASSERT_NONNULL(maybeException); + KJ_EXPECT(exception.getDescription() == "remote exception: test exception"); + KJ_EXPECT(exception.getRemoteTrace() == nullptr); +} + KJ_TEST("method throws exception with trace encoder") { TestContext context; @@ -1510,7 +1562,7 @@ KJ_TEST("export the same promise twice") { KJ_EXPECT(interceptCount == 3); // Now try sending a non-promise cap. We'll send all these requests at once before waiting on - // any of them since these will acutally complete.k + // any of them since these will actually complete. exportIsPromise = false; expectedExportNumber = 2; auto promise4 = sendReq(normalCap); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty-test.c++ index a3e5749c7c0..5bf2215de52 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty-test.c++ @@ -147,6 +147,7 @@ TEST(TwoPartyNetwork, Basic) { clock.increment(1 * kj::SECONDS); KJ_EXPECT(network.getCurrentQueueCount() == 1); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); KJ_EXPECT(network.getCurrentQueueSize() > 0); KJ_EXPECT(network.getOutgoingMessageWaitTime() == 1 * kj::SECONDS); size_t oldSize = network.getCurrentQueueSize(); @@ -158,6 +159,7 @@ TEST(TwoPartyNetwork, Basic) { auto promise1 = request1.send(); KJ_EXPECT(network.getCurrentQueueCount() == 2); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); KJ_EXPECT(network.getCurrentQueueSize() > oldSize); KJ_EXPECT(network.getOutgoingMessageWaitTime() == 1 * kj::SECONDS); oldSize = network.getCurrentQueueSize(); @@ -167,6 +169,7 @@ TEST(TwoPartyNetwork, Basic) { auto promise2 = request2.send(); KJ_EXPECT(network.getCurrentQueueCount() == 3); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); KJ_EXPECT(network.getCurrentQueueSize() > oldSize); oldSize = network.getCurrentQueueSize(); @@ -184,6 +187,7 @@ TEST(TwoPartyNetwork, Basic) { EXPECT_EQ(0, callCount); KJ_EXPECT(network.getCurrentQueueCount() == 4); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); KJ_EXPECT(network.getCurrentQueueSize() > oldSize); // Oldest message is now 2 seconds old KJ_EXPECT(network.getOutgoingMessageWaitTime() == 2 * kj::SECONDS); @@ -213,6 +217,18 @@ TEST(TwoPartyNetwork, Basic) { // Now nothing is queued. KJ_EXPECT(network.getCurrentQueueCount() == 0); KJ_EXPECT(network.getCurrentQueueSize() == 0); + + // Ensure that sending a message after not sending one for some time + // doesn't return incorrect waitTime statistics. + clock.increment(10 * kj::SECONDS); + + auto request4 = client.fooRequest(); + request4.setI(123); + request4.setJ(true); + auto promise4 = request4.send(); + + KJ_EXPECT(network.getCurrentQueueCount() == 1); + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 0 * kj::SECONDS); } TEST(TwoPartyNetwork, Pipelining) { @@ -390,8 +406,10 @@ TEST(TwoPartyNetwork, Abort) { msg->send(); } - auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope)); - EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs().which()); + { + auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope)); + EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs().which()); + } EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr); } @@ -720,6 +738,85 @@ KJ_TEST("Streaming over RPC") { } } +KJ_TEST("Streaming over a chain of local and remote RPC calls") { + // This test verifies that a local RPC call that eventually resolves to a remote RPC call will + // still support streaming calls over the remote connection. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + // Set up a local server that will eventually delegate requests to a remote server. + auto localPaf = kj::newPromiseAndFulfiller(); + test::TestStreaming::Client promisedClient(kj::mv(localPaf.promise)); + + uint count = 0; + auto req = promisedClient.doStreamIRequest(); + req.setI(++count); + auto promise = req.send(); + + // Expect streaming request to be blocked on promised client. + KJ_EXPECT(!promise.poll(waitScope)); + + // Set up a remote server with a flow control window for streaming. + auto pipe = kj::newTwoWayPipe(); + + size_t window = 1024; + size_t clientWritten = 0; + size_t serverWritten = 0; + + pipe.ends[0] = kj::heap(kj::mv(pipe.ends[0]), window, clientWritten); + pipe.ends[1] = kj::heap(kj::mv(pipe.ends[1]), window, serverWritten); + + auto remotePaf = kj::newPromiseAndFulfiller(); + test::TestStreaming::Client serverCap(kj::mv(remotePaf.promise)); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(serverCap), rpc::twoparty::Side::SERVER); + + auto clientCap = tpClient.bootstrap().castAs(); + + // Expect streaming request to be unblocked by fulfilling promised client with remote server. + localPaf.fulfiller->fulfill(kj::mv(clientCap)); + KJ_EXPECT(promise.poll(waitScope)); + + // Send stream requests until we can't anymore. + while (promise.poll(waitScope)) { + promise.wait(waitScope); + + auto req = promisedClient.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + KJ_ASSERT(count < 1000); + } + + // Expect several stream requests to have fit in the flow control window. + KJ_EXPECT(count > 5); + + auto finishReq = promisedClient.finishStreamRequest(); + auto finishPromise = finishReq.send(); + KJ_EXPECT(!finishPromise.poll(waitScope)); + + // Finish calls on server + auto ownServer = kj::heap(); + auto& server = *ownServer; + remotePaf.fulfiller->fulfill(kj::mv(ownServer)); + KJ_EXPECT(!promise.poll(waitScope)); + + uint countReceived = 0; + for (uint i = 0; i < count; i++) { + KJ_EXPECT(server.iSum == ++countReceived); + server.iSum = 0; + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + if (i < count - 1) { + KJ_EXPECT(!finishPromise.poll(waitScope)); + } + } + + KJ_EXPECT(finishPromise.poll(waitScope)); + finishPromise.wait(waitScope); +} + KJ_TEST("Streaming over RPC then unwrap with CapabilitySet") { kj::EventLoop loop; kj::WaitScope waitScope(loop); @@ -891,6 +988,284 @@ KJ_TEST("write error propagates to read error") { } } +class TestStreamingCancellationBug final: public test::TestStreaming::Server { +public: + uint iSum = 0; + kj::Maybe>> fulfiller; + + kj::Promise doStreamI(DoStreamIContext context) override { + auto paf = kj::newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + return paf.promise.then([this,context]() mutable { + // Don't count the sum until here so we actually detect if the call is canceled. + iSum += context.getParams().getI(); + }); + } + + kj::Promise finishStream(FinishStreamContext context) override { + auto results = context.getResults(); + results.setTotalI(iSum); + return kj::READY_NOW; + } +}; + +KJ_TEST("Streaming over RPC no premature cancellation when client dropped") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client serverCap = kj::mv(ownServer); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(serverCap), rpc::twoparty::Side::SERVER); + + auto client = tpClient.bootstrap().castAs(); + + kj::Promise promise1 = nullptr, promise2 = nullptr; + + { + auto req = client.doStreamIRequest(); + req.setI(123); + promise1 = req.send(); + } + { + auto req = client.doStreamIRequest(); + req.setI(456); + promise2 = req.send(); + } + + auto finishPromise = client.finishStreamRequest().send(); + + KJ_EXPECT(server.iSum == 0); + + // Drop the client. This shouldn't cause a problem for the already-running RPCs. + { auto drop = kj::mv(client); } + + while (!finishPromise.poll(waitScope)) { + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + } + + finishPromise.wait(waitScope); + KJ_EXPECT(server.iSum == 579); +} + +KJ_TEST("Dropping capability during call doesn't destroy server") { + class TestInterfaceImpl final: public test::TestInterface::Server { + // An object which increments a count in the constructor and decrements it in the destructor, + // to detect when it is destroyed. The object's foo() method also sets a fulfiller to use to + // cause the method to complete. + public: + TestInterfaceImpl(uint& count, kj::Maybe>>& fulfillerSlot) + : count(count), fulfillerSlot(fulfillerSlot) { ++count; } + ~TestInterfaceImpl() noexcept(false) { --count; } + + kj::Promise foo(FooContext context) override { + auto paf = kj::newPromiseAndFulfiller(); + fulfillerSlot = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + private: + uint& count; + kj::Maybe>>& fulfillerSlot; + }; + + class TestBootstrapImpl final: public test::TestMoreStuff::Server { + // Bootstrap object which just vends instances of `TestInterfaceImpl`. + public: + TestBootstrapImpl(uint& count, kj::Maybe>>& fulfillerSlot) + : count(count), fulfillerSlot(fulfillerSlot) {} + + kj::Promise getHeld(GetHeldContext context) override { + context.initResults().setCap(kj::heap(count, fulfillerSlot)); + return kj::READY_NOW; + } + + private: + uint& count; + kj::Maybe>>& fulfillerSlot; + }; + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + auto pipe = kj::newTwoWayPipe(); + + uint count = 0; + kj::Maybe>> fulfillerSlot; + test::TestMoreStuff::Client bootstrap = kj::heap(count, fulfillerSlot); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(bootstrap), rpc::twoparty::Side::SERVER); + + auto cap = tpClient.bootstrap().castAs().getHeldRequest().send().getCap(); + + waitScope.poll(); + auto promise = cap.fooRequest().send(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(fulfillerSlot != nullptr); + + // Dropping the capability should not destroy the server as long as the call is still + // outstanding. + {auto drop = kj::mv(cap);} + + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(count == 1); + + // Cancelling the call still should not destroy the server because the call is not marked to + // allow cancellation. So the call should keep running. + {auto drop = kj::mv(promise);} + + waitScope.poll(); + KJ_EXPECT(count == 1); + + // When the call completes, only then should the server be dropped. + KJ_ASSERT_NONNULL(fulfillerSlot)->fulfill(); + + waitScope.poll(); + KJ_EXPECT(count == 0); +} + +RemotePromise getCallSequence( + test::TestCallOrder::Client& client, uint expected) { + auto req = client.getCallSequenceRequest(); + req.setExpected(expected); + return req.send(); +} + +KJ_TEST("Two-hop embargo") { + // Copied from `TEST(Rpc, Embargo)` in `rpc-test.c++`, adapted to involve a two-hop path through + // a proxy. This tests what happens when disembargoes on multiple hops are happening in parallel. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0, handleCount = 0; + + // Set up two two-party RPC connections in series. The middle node just proxies requests through. + auto frontPipe = kj::newTwoWayPipe(); + auto backPipe = kj::newTwoWayPipe(); + TwoPartyClient tpClient(*frontPipe.ends[0]); + TwoPartyClient proxyBack(*backPipe.ends[0]); + TwoPartyClient proxyFront(*frontPipe.ends[1], proxyBack.bootstrap(), rpc::twoparty::Side::SERVER); + TwoPartyClient tpServer(*backPipe.ends[1], kj::heap(callCount, handleCount), + rpc::twoparty::Side::SERVER); + + // Perform some logic that does a bunch of promise pipelining, including passing a capability + // from the client to the server and back to the client, and making promise-pipelined calls on + // that capability. This should exercise the promise resolution and disembargo code. + auto client = tpClient.bootstrap().castAs(); + + auto cap = test::TestCallOrder::Client(kj::heap()); + + auto earlyCall = client.getCallSequenceRequest().send(); + + auto echoRequest = client.echoRequest(); + echoRequest.setCap(cap); + auto echo = echoRequest.send(); + + auto pipeline = echo.getCap(); + + auto call0 = getCallSequence(pipeline, 0); + auto call1 = getCallSequence(pipeline, 1); + + earlyCall.wait(waitScope); + + auto call2 = getCallSequence(pipeline, 2); + + auto resolved = echo.wait(waitScope).getCap(); + + auto call3 = getCallSequence(pipeline, 3); + auto call4 = getCallSequence(pipeline, 4); + auto call5 = getCallSequence(pipeline, 5); + + EXPECT_EQ(0, call0.wait(waitScope).getN()); + EXPECT_EQ(1, call1.wait(waitScope).getN()); + EXPECT_EQ(2, call2.wait(waitScope).getN()); + EXPECT_EQ(3, call3.wait(waitScope).getN()); + EXPECT_EQ(4, call4.wait(waitScope).getN()); + EXPECT_EQ(5, call5.wait(waitScope).getN()); +} + +class TestCallOrderImplAsPromise final: public test::TestCallOrder::Server { + // This is an implementation of TestCallOrder that presents itself as a promise by implementing + // `shortenPath()`, although it never resolves to anything (`shortenPath()` never completes). + // This tests deeper code paths in promise resolution and embargo code. +public: + template + TestCallOrderImplAsPromise(Params&&... params): inner(kj::fwd(params)...) {} + + kj::Promise getCallSequence(GetCallSequenceContext context) override { + return inner.getCallSequence(context); + } + + kj::Maybe> shortenPath() override { + // Make this object appear to be a promise. + return kj::Promise(kj::NEVER_DONE); + } + +private: + TestCallOrderImpl inner; +}; + +KJ_TEST("Two-hop embargo") { + // Same as above, but the eventual resolution is itself a promise. This verifies that + // handleDisembargo() only waits for the target to resolve back to the capability that the + // disembargo should reflect to, but not beyond that. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0, handleCount = 0; + + // Set up two two-party RPC connections in series. The middle node just proxies requests through. + auto frontPipe = kj::newTwoWayPipe(); + auto backPipe = kj::newTwoWayPipe(); + TwoPartyClient tpClient(*frontPipe.ends[0]); + TwoPartyClient proxyBack(*backPipe.ends[0]); + TwoPartyClient proxyFront(*frontPipe.ends[1], proxyBack.bootstrap(), rpc::twoparty::Side::SERVER); + TwoPartyClient tpServer(*backPipe.ends[1], kj::heap(callCount, handleCount), + rpc::twoparty::Side::SERVER); + + // Perform some logic that does a bunch of promise pipelining, including passing a capability + // from the client to the server and back to the client, and making promise-pipelined calls on + // that capability. This should exercise the promise resolution and disembargo code. + auto client = tpClient.bootstrap().castAs(); + + auto cap = test::TestCallOrder::Client(kj::heap()); + + auto earlyCall = client.getCallSequenceRequest().send(); + + auto echoRequest = client.echoRequest(); + echoRequest.setCap(cap); + auto echo = echoRequest.send(); + + auto pipeline = echo.getCap(); + + auto call0 = getCallSequence(pipeline, 0); + auto call1 = getCallSequence(pipeline, 1); + + earlyCall.wait(waitScope); + + auto call2 = getCallSequence(pipeline, 2); + + auto resolved = echo.wait(waitScope).getCap(); + + auto call3 = getCallSequence(pipeline, 3); + auto call4 = getCallSequence(pipeline, 4); + auto call5 = getCallSequence(pipeline, 5); + + EXPECT_EQ(0, call0.wait(waitScope).getN()); + EXPECT_EQ(1, call1.wait(waitScope).getN()); + EXPECT_EQ(2, call2.wait(waitScope).getN()); + EXPECT_EQ(3, call3.wait(waitScope).getN()); + EXPECT_EQ(4, call4.wait(waitScope).getN()); + EXPECT_EQ(5, call5.wait(waitScope).getN()); +} + } // namespace } // namespace _ } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.c++ index f1bede8cb8a..09c84bbb55d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.c++ @@ -66,14 +66,20 @@ TwoPartyVatNetwork::TwoPartyVatNetwork( TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, ReaderOptions receiveOptions, const kj::MonotonicClock& clock) - : TwoPartyVatNetwork(kj::Own(kj::heap(stream)), - 0, side, receiveOptions, clock) {} + : TwoPartyVatNetwork( + kj::Own(kj::heap( + stream, IncomingRpcMessage::getShortLivedCallback())), + 0, side, receiveOptions, clock) {} TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage, rpc::twoparty::Side side, ReaderOptions receiveOptions, const kj::MonotonicClock& clock) - : TwoPartyVatNetwork(kj::Own(kj::heap(stream)), - maxFdsPerMessage, side, receiveOptions, clock) {} + : TwoPartyVatNetwork( + kj::Own(kj::heap( + stream, IncomingRpcMessage::getShortLivedCallback())), + maxFdsPerMessage, side, receiveOptions, clock) {} + +TwoPartyVatNetwork::~TwoPartyVatNetwork() noexcept(false) {}; MessageStream& TwoPartyVatNetwork::getStream() { KJ_SWITCH_ONEOF(stream) { @@ -148,19 +154,44 @@ public: return; } - network.currentQueueSize += size * sizeof(capnp::word); - ++network.currentQueueCount; - auto deferredSizeUpdate = kj::defer([&network = network, size]() mutable { - network.currentQueueSize -= size * sizeof(capnp::word); - --network.currentQueueCount; - }); - auto sendTime = network.clock.now(); - network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down") - .then([this, sendTime]() { - return kj::evalNow([&]() { + if (network.queuedMessages.size() == 0) { + // Optimistically set sendTime when there's no messages in the queue. Without this, sending + // a message after a long delay could cause getOutgoingMessageWaitTime() to return excessively + // long wait times if it is called during the time period after send() is called, + // but before the write occurs, as we increment currentQueueCount synchronously, but + // asynchronously update currentOutgoingMessageSendTime. + network.currentOutgoingMessageSendTime = sendTime; + } + + // Instead of sending each new message as soon as possible, we attempt to batch together small + // messages by delaying when we send them using evalLast. This allows us to group together + // related small messages, reducing the number of syscalls we make. + auto& previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down"); + bool alreadyPendingSend = !network.queuedMessages.empty(); + network.currentQueueSize += message.sizeInWords() * sizeof(word); + network.queuedMessages.add(kj::addRef(*this)); + if (alreadyPendingSend) { + // The first send sets up an evalLast that will clear out pendingMessages when it's sent. + // If pendingMessages is non-empty, then there must already be a callback waiting to send + // them. + return; + } + + // On the other hand, if pendingMessages was empty, then we should set up the delayed write. + network.previousWrite = previousWrite.then([this, sendTime]() { + return kj::evalLast([this, sendTime]() -> kj::Promise { network.currentOutgoingMessageSendTime = sendTime; - return network.getStream().writeMessage(fds, message); + // Swap out the connection's pending messages and write all of them together. + auto ownMessages = kj::mv(network.queuedMessages); + network.currentQueueSize = 0; + auto messages = + kj::heapArray(ownMessages.size()); + for (int i = 0; i < messages.size(); ++i) { + messages[i].segments = ownMessages[i]->message.getSegmentsForOutput(); + messages[i].fds = ownMessages[i]->fds; + } + return network.getStream().writeMessages(messages).attach(kj::mv(ownMessages), kj::mv(messages)); }).catch_([this](kj::Exception&& e) { // Since no one checks write failures, we need to propagate them into read failures, // otherwise we might get stuck sending all messages into a black hole and wondering why @@ -171,7 +202,7 @@ public: } kj::throwRecoverableException(kj::mv(e)); }); - }).attach(kj::addRef(*this), kj::mv(deferredSizeUpdate)) + }).attach(kj::addRef(*this)) // Note that it's important that the eagerlyEvaluate() come *after* the attach() because // otherwise the message (and any capabilities in it) will not be released until a new // message is written! (Kenton once spent all afternoon tracking this down...) @@ -189,7 +220,7 @@ private: }; kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() { - if (currentQueueCount > 0) { + if (queuedMessages.size() > 0) { return clock.now() - currentOutgoingMessageSendTime; } else { return 0 * kj::SECONDS; @@ -310,31 +341,46 @@ kj::Promise TwoPartyVatNetwork::shutdown() { // ======================================================================================= -TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface) - : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {} +TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface, + kj::Maybe> traceEncoder) + : bootstrapInterface(kj::mv(bootstrapInterface)), + traceEncoder(kj::mv(traceEncoder)), + tasks(*this) {} struct TwoPartyServer::AcceptedConnection { kj::Own connection; TwoPartyVatNetwork network; RpcSystem rpcSystem; - explicit AcceptedConnection(Capability::Client bootstrapInterface, + explicit AcceptedConnection(TwoPartyServer& parent, kj::Own&& connectionParam) : connection(kj::mv(connectionParam)), network(*connection, rpc::twoparty::Side::SERVER), - rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} + rpcSystem(makeRpcServer(network, kj::cp(parent.bootstrapInterface))) { + init(parent); + } - explicit AcceptedConnection(Capability::Client bootstrapInterface, + explicit AcceptedConnection(TwoPartyServer& parent, kj::Own&& connectionParam, uint maxFdsPerMessage) : connection(kj::mv(connectionParam)), network(kj::downcast(*connection), maxFdsPerMessage, rpc::twoparty::Side::SERVER), - rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} + rpcSystem(makeRpcServer(network, kj::cp(parent.bootstrapInterface))) { + init(parent); + } + + void init(TwoPartyServer& parent) { + KJ_IF_MAYBE(t, parent.traceEncoder) { + rpcSystem.setTraceEncoder([&func = *t](const kj::Exception& e) { + return func(e); + }); + } + } }; void TwoPartyServer::accept(kj::Own&& connection) { - auto connectionState = kj::heap(bootstrapInterface, kj::mv(connection)); + auto connectionState = kj::heap(*this, kj::mv(connection)); // Run the connection until disconnect. auto promise = connectionState->network.onDisconnect(); @@ -344,7 +390,7 @@ void TwoPartyServer::accept(kj::Own&& connection) { void TwoPartyServer::accept( kj::Own&& connection, uint maxFdsPerMessage) { auto connectionState = kj::heap( - bootstrapInterface, kj::mv(connection), maxFdsPerMessage); + *this, kj::mv(connection), maxFdsPerMessage); // Run the connection until disconnect. auto promise = connectionState->network.onDisconnect(); @@ -352,7 +398,7 @@ void TwoPartyServer::accept( } kj::Promise TwoPartyServer::accept(kj::AsyncIoStream& connection) { - auto connectionState = kj::heap(bootstrapInterface, + auto connectionState = kj::heap(*this, kj::Own(&connection, kj::NullDisposer::instance)); // Run the connection until disconnect. @@ -362,7 +408,7 @@ kj::Promise TwoPartyServer::accept(kj::AsyncIoStream& connection) { kj::Promise TwoPartyServer::accept( kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) { - auto connectionState = kj::heap(bootstrapInterface, + auto connectionState = kj::heap(*this, kj::Own(&connection, kj::NullDisposer::instance), maxFdsPerMessage); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp index 0b670e8ac3f..5f0e2150e7f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp @@ -162,8 +162,6 @@ struct JoinResult { # implements the join by waiting for all the `JoinKeyParts` and then performing its own join on # them, then going back and answering all the join requests afterwards. - cap @2 :AnyPointer; + cap @2 :Capability; # One of the JoinResults will have a non-null `cap` which is the joined capability. - # - # TODO(cleanup): Change `AnyPointer` to `Capability` when that is supported. } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.c++ index 64ae32bf2e5..6809cebcf42 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.c++ @@ -38,7 +38,7 @@ static const ::capnp::_::AlignedData<26> b_9fd69ebc87b9719c = { static const uint16_t m_9fd69ebc87b9719c[] = {1, 0}; const ::capnp::_::RawSchema s_9fd69ebc87b9719c = { 0x9fd69ebc87b9719c, b_9fd69ebc87b9719c.words, 26, nullptr, m_9fd69ebc87b9719c, - 0, 2, nullptr, nullptr, nullptr, { &s_9fd69ebc87b9719c, nullptr, nullptr, 0, 0, nullptr } + 0, 2, nullptr, nullptr, nullptr, { &s_9fd69ebc87b9719c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(Side_9fd69ebc87b9719c, 9fd69ebc87b9719c); @@ -86,7 +86,7 @@ static const uint16_t m_d20b909fee733a8e[] = {0}; static const uint16_t i_d20b909fee733a8e[] = {0}; const ::capnp::_::RawSchema s_d20b909fee733a8e = { 0xd20b909fee733a8e, b_d20b909fee733a8e.words, 33, d_d20b909fee733a8e, m_d20b909fee733a8e, - 1, 1, i_d20b909fee733a8e, nullptr, nullptr, { &s_d20b909fee733a8e, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_d20b909fee733a8e, nullptr, nullptr, { &s_d20b909fee733a8e, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_b88d09a9c5f39817 = { @@ -131,7 +131,7 @@ static const uint16_t m_b88d09a9c5f39817[] = {0}; static const uint16_t i_b88d09a9c5f39817[] = {0}; const ::capnp::_::RawSchema s_b88d09a9c5f39817 = { 0xb88d09a9c5f39817, b_b88d09a9c5f39817.words, 34, nullptr, m_b88d09a9c5f39817, - 0, 1, i_b88d09a9c5f39817, nullptr, nullptr, { &s_b88d09a9c5f39817, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_b88d09a9c5f39817, nullptr, nullptr, { &s_b88d09a9c5f39817, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<18> b_89f389b6fd4082c1 = { @@ -158,7 +158,7 @@ static const ::capnp::_::AlignedData<18> b_89f389b6fd4082c1 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_89f389b6fd4082c1 = { 0x89f389b6fd4082c1, b_89f389b6fd4082c1.words, 18, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_89f389b6fd4082c1, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_89f389b6fd4082c1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<19> b_b47f4979672cb59d = { @@ -186,7 +186,7 @@ static const ::capnp::_::AlignedData<19> b_b47f4979672cb59d = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_b47f4979672cb59d = { 0xb47f4979672cb59d, b_b47f4979672cb59d.words, 19, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_b47f4979672cb59d, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_b47f4979672cb59d, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_95b29059097fca83 = { @@ -262,7 +262,7 @@ static const uint16_t m_95b29059097fca83[] = {0, 1, 2}; static const uint16_t i_95b29059097fca83[] = {0, 1, 2}; const ::capnp::_::RawSchema s_95b29059097fca83 = { 0x95b29059097fca83, b_95b29059097fca83.words, 65, nullptr, m_95b29059097fca83, - 0, 3, i_95b29059097fca83, nullptr, nullptr, { &s_95b29059097fca83, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_95b29059097fca83, nullptr, nullptr, { &s_95b29059097fca83, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_9d263a3630b7ebee = { @@ -325,7 +325,7 @@ static const ::capnp::_::AlignedData<65> b_9d263a3630b7ebee = { 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, @@ -338,7 +338,7 @@ static const uint16_t m_9d263a3630b7ebee[] = {2, 0, 1}; static const uint16_t i_9d263a3630b7ebee[] = {0, 1, 2}; const ::capnp::_::RawSchema s_9d263a3630b7ebee = { 0x9d263a3630b7ebee, b_9d263a3630b7ebee.words, 65, nullptr, m_9d263a3630b7ebee, - 0, 3, i_9d263a3630b7ebee, nullptr, nullptr, { &s_9d263a3630b7ebee, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_9d263a3630b7ebee, nullptr, nullptr, { &s_9d263a3630b7ebee, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE } // namespace schemas @@ -351,51 +351,75 @@ namespace rpc { namespace twoparty { // VatId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t VatId::_capnpPrivate::dataWordSize; constexpr uint16_t VatId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind VatId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* VatId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ProvisionId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ProvisionId::_capnpPrivate::dataWordSize; constexpr uint16_t ProvisionId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ProvisionId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ProvisionId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // RecipientId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t RecipientId::_capnpPrivate::dataWordSize; constexpr uint16_t RecipientId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind RecipientId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* RecipientId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ThirdPartyCapId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ThirdPartyCapId::_capnpPrivate::dataWordSize; constexpr uint16_t ThirdPartyCapId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ThirdPartyCapId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ThirdPartyCapId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // JoinKeyPart +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t JoinKeyPart::_capnpPrivate::dataWordSize; constexpr uint16_t JoinKeyPart::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind JoinKeyPart::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* JoinKeyPart::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // JoinResult +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t JoinResult::_capnpPrivate::dataWordSize; constexpr uint16_t JoinResult::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind JoinResult::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* JoinResult::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.h index d447706c922..d3d8153a8b7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.capnp.h @@ -5,8 +5,13 @@ #include #include +#if !CAPNP_LITE +#include +#endif // !CAPNP_LITE -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif @@ -532,7 +537,9 @@ class JoinResult::Reader { inline bool getSucceeded() const; inline bool hasCap() const; - inline ::capnp::AnyPointer::Reader getCap() const; +#if !CAPNP_LITE + inline ::capnp::Capability::Client getCap() const; +#endif // !CAPNP_LITE private: ::capnp::_::StructReader _reader; @@ -569,8 +576,13 @@ class JoinResult::Builder { inline void setSucceeded(bool value); inline bool hasCap(); - inline ::capnp::AnyPointer::Builder getCap(); - inline ::capnp::AnyPointer::Builder initCap(); +#if !CAPNP_LITE + inline ::capnp::Capability::Client getCap(); + inline void setCap( ::capnp::Capability::Client&& value); + inline void setCap( ::capnp::Capability::Client& value); + inline void adoptCap(::capnp::Orphan< ::capnp::Capability>&& value); + inline ::capnp::Orphan< ::capnp::Capability> disownCap(); +#endif // !CAPNP_LITE private: ::capnp::_::StructBuilder _builder; @@ -590,6 +602,7 @@ class JoinResult::Pipeline { inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) : _typeless(kj::mv(typeless)) {} + inline ::capnp::Capability::Client getCap(); private: ::capnp::AnyPointer::Pipeline _typeless; friend class ::capnp::PipelineHook; @@ -706,20 +719,36 @@ inline bool JoinResult::Builder::hasCap() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::AnyPointer::Reader JoinResult::Reader::getCap() const { - return ::capnp::AnyPointer::Reader(_reader.getPointerField( +#if !CAPNP_LITE +inline ::capnp::Capability::Client JoinResult::Reader::getCap() const { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::AnyPointer::Builder JoinResult::Builder::getCap() { - return ::capnp::AnyPointer::Builder(_builder.getPointerField( +inline ::capnp::Capability::Client JoinResult::Builder::getCap() { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::AnyPointer::Builder JoinResult::Builder::initCap() { - auto result = ::capnp::AnyPointer::Builder(_builder.getPointerField( +inline ::capnp::Capability::Client JoinResult::Pipeline::getCap() { + return ::capnp::Capability::Client(_typeless.getPointerField(0).asCap()); +} +inline void JoinResult::Builder::setCap( ::capnp::Capability::Client&& cap) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(cap)); +} +inline void JoinResult::Builder::setCap( ::capnp::Capability::Client& cap) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), cap); +} +inline void JoinResult::Builder::adoptCap( + ::capnp::Orphan< ::capnp::Capability>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Capability> JoinResult::Builder::disownCap() { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); - result.clear(); - return result; } +#endif // !CAPNP_LITE } // namespace } // namespace diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.h b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.h index 58fed747615..c280e62d401 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc-twoparty.h @@ -22,7 +22,7 @@ #pragma once #include "rpc.h" -#include "message.h" +#include #include #include #include @@ -79,7 +79,8 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, // clock is used for calculating the oldest queued message age, which is a useful metric for // detecting queue overload - KJ_DISALLOW_COPY(TwoPartyVatNetwork); + ~TwoPartyVatNetwork() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(TwoPartyVatNetwork); kj::Promise onDisconnect() { return disconnectPromise.addBranch(); } // Returns a promise that resolves when the peer disconnects. @@ -90,7 +91,7 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, // Get the number of bytes worth of outgoing messages that are currently queued in memory waiting // to be sent on this connection. This may be useful for backpressure. - size_t getCurrentQueueCount() { return currentQueueCount; } + size_t getCurrentQueueCount() { return queuedMessages.size(); } // Get the count of outgoing messages that are currently queued in memory waiting // to be sent on this connection. This may be useful for backpressure. @@ -135,8 +136,8 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, kj::ForkedPromise disconnectPromise = nullptr; + kj::Vector> queuedMessages; size_t currentQueueSize = 0; - size_t currentQueueCount = 0; const kj::MonotonicClock& clock; kj::TimePoint currentOutgoingMessageSendTime; @@ -183,10 +184,12 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, class TwoPartyServer: private kj::TaskSet::ErrorHandler { // Convenience class which implements a simple server which accepts connections on a listener - // socket and serices them as two-party connections. + // socket and services them as two-party connections. public: - explicit TwoPartyServer(Capability::Client bootstrapInterface); + explicit TwoPartyServer(Capability::Client bootstrapInterface, + kj::Maybe> traceEncoder = nullptr); + // `traceEncoder`, if provided, will be passed on to `rpcSystem.setTraceEncoder()`. void accept(kj::Own&& connection); void accept(kj::Own&& connection, uint maxFdsPerMessage); @@ -220,6 +223,7 @@ class TwoPartyServer: private kj::TaskSet::ErrorHandler { private: Capability::Client bootstrapInterface; + kj::Maybe> traceEncoder; kj::TaskSet tasks; struct AcceptedConnection; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.c++ index a75fd72199e..6118e3166de 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.c++ @@ -112,8 +112,16 @@ Orphan> fromPipelineOps( } kj::Exception toException(const rpc::Exception::Reader& exception) { + auto reason = [&]() { + if (exception.getReason().startsWith("remote exception: ")) { + return kj::str(exception.getReason()); + } else { + return kj::str("remote exception: ", exception.getReason()); + } + }(); + kj::Exception result(static_cast(exception.getType()), - "(remote)", 0, kj::str("remote exception: ", exception.getReason())); + "(remote)", 0, kj::mv(reason)); if (exception.hasTrace()) { result.setRemoteTrace(kj::str(exception.getTrace())); } @@ -157,15 +165,33 @@ uint exceptionSizeHint(const kj::Exception& exception) { return sizeInWords() + exception.getDescription().size() / sizeof(word) + 1; } +ClientHook::CallHints callHintsFromReader(rpc::Call::Reader reader) { + ClientHook::CallHints hints; + hints.noPromisePipelining = reader.getNoPromisePipelining(); + hints.onlyPromisePipeline = reader.getOnlyPromisePipeline(); + return hints; +} + // ======================================================================================= +template +static constexpr Id highBit() { + return 1u << (sizeof(Id) * 8 - 1); +} + template class ExportTable { // Table mapping integers to T, where the integers are chosen locally. public: + bool isHigh(Id& id) { + return (id & highBit()) != 0; + } + kj::Maybe find(Id id) { - if (id < slots.size() && slots[id] != nullptr) { + if (isHigh(id)) { + return highSlots.find(id); + } else if (id < slots.size() && slots[id] != nullptr) { return slots[id]; } else { return nullptr; @@ -178,16 +204,23 @@ public: // `entry` is a reference to the entry being released -- we require this in order to prove // that the caller has already done a find() to check that this entry exists. We can't check // ourselves because the caller may have nullified the entry in the meantime. - KJ_DREQUIRE(&entry == &slots[id]); - T toRelease = kj::mv(slots[id]); - slots[id] = T(); - freeIds.push(id); - return toRelease; + + if (isHigh(id)) { + auto& slot = KJ_REQUIRE_NONNULL(highSlots.findEntry(id)); + return highSlots.release(slot).value; + } else { + KJ_DREQUIRE(&entry == &slots[id]); + T toRelease = kj::mv(slots[id]); + slots[id] = T(); + freeIds.push(id); + return toRelease; + } } T& next(Id& id) { if (freeIds.empty()) { id = slots.size(); + KJ_ASSERT(!isHigh(id), "2^31 concurrent questions?!!?!"); return slots.add(); } else { id = freeIds.top(); @@ -196,6 +229,25 @@ public: } } + T& nextHigh(Id& id) { + // Choose an ID with the top bit set in round-robin fashion, but don't choose an ID that + // is still in use. + + KJ_ASSERT(highSlots.size() < Id(kj::maxValue) / 2); // avoid infinite loop below. + + bool created = false; + T* slot; + while (!created) { + id = highCounter++ | highBit(); + slot = &highSlots.findOrCreate(id, [&]() { + created = true; + return typename kj::HashMap::Entry { id, T() }; + }); + } + + return *slot; + } + template void forEach(Func&& func) { for (Id i = 0; i < slots.size(); i++) { @@ -203,11 +255,24 @@ public: func(i, slots[i]); } } + for (auto& slot: highSlots) { + func(slot.key, slot.value); + } + } + + void release() { + // Release memory backing the table. + { auto drop = kj::mv(slots); } + { auto drop = kj::mv(freeIds); } + { auto drop = kj::mv(highSlots); } } private: kj::Vector slots; std::priority_queue, std::greater> freeIds; + + kj::HashMap highSlots; + Id highCounter = 0; }; template @@ -325,6 +390,12 @@ public: } void disconnect(kj::Exception&& exception) { + // Shut down the connection with the given error. + // + // This will cancel `tasks`, so cannot be called from inside a task in `tasks`. Instead, use + // `tasks.add(exception)` to schedule a shutdown, since any error thrown by a task will be + // passed to `disconnect()` later. + // After disconnect(), the RpcSystem could be destroyed, making `traceEncoder` a dangling // reference, so null it out before we return from here. We don't need it anymore once // disconnected anyway. @@ -350,33 +421,47 @@ public: // all future calls on this connection. networkException.addTraceHere(); + // Set our connection state to Disconnected now so that no one tries to write any messages to + // it in their destructors. + auto dyingConnection = kj::mv(connection.get()); + connection.init(kj::cp(networkException)); + KJ_IF_MAYBE(newException, kj::runCatchingExceptions([&]() { // Carefully pull all the objects out of the tables prior to releasing them because their // destructors could come back and mess with the tables. kj::Vector> pipelinesToRelease; kj::Vector> clientsToRelease; - kj::Vector>> tailCallsToRelease; + kj::Vector tasksToRelease; kj::Vector> resolveOpsToRelease; + KJ_DEFER(tasks.clear()); // All current questions complete with exceptions. questions.forEach([&](QuestionId id, Question& question) { KJ_IF_MAYBE(questionRef, question.selfRef) { // QuestionRef still present. questionRef->reject(kj::cp(networkException)); + + // We need to fully disconnect each QuestionRef otherwise it holds a reference back to + // the connection state. Meanwhile `tasks` may hold streaming calls that end up holding + // these QuestionRefs. Technically this is a cyclic reference, but as long as the cycle + // is broken on disconnect (which happens when the RpcSystem itself is destroyed), then + // we're OK. + questionRef->disconnect(); } }); + // Since we've disconnected the QuestionRefs, they won't clean up the questions table for + // us, so do that here. + questions.release(); answers.forEach([&](AnswerId id, Answer& answer) { KJ_IF_MAYBE(p, answer.pipeline) { pipelinesToRelease.add(kj::mv(*p)); } - KJ_IF_MAYBE(promise, answer.redirectedResults) { - tailCallsToRelease.add(kj::mv(*promise)); - } + tasksToRelease.add(kj::mv(answer.task)); KJ_IF_MAYBE(context, answer.callContext) { - context->requestCancel(); + context->finish(); } }); @@ -408,31 +493,35 @@ public: // Send an abort message, but ignore failure. kj::runCatchingExceptions([&]() { - auto message = connection.get()->newOutgoingMessage( + auto message = dyingConnection->newOutgoingMessage( messageSizeHint() + exceptionSizeHint(exception)); fromException(exception, message->getBody().getAs().initAbort()); message->send(); }); // Indicate disconnect. - auto shutdownPromise = connection.get()->shutdown() - .attach(kj::mv(connection.get())) + auto shutdownPromise = dyingConnection->shutdown() + .attach(kj::mv(dyingConnection)) .then([]() -> kj::Promise { return kj::READY_NOW; }, - [origException = kj::mv(exception)](kj::Exception&& e) -> kj::Promise { + [this, origException = kj::mv(exception)](kj::Exception&& shutdownException) -> kj::Promise { // Don't report disconnects as an error. - if (e.getType() == kj::Exception::Type::DISCONNECTED) { + if (shutdownException.getType() == kj::Exception::Type::DISCONNECTED) { return kj::READY_NOW; } // If the error is just what was passed in to disconnect(), don't report it back out // since it shouldn't be anything the caller doesn't already know about. - if (e.getType() == origException.getType() && - e.getDescription() == origException.getDescription()) { + if (shutdownException.getType() == origException.getType() && + shutdownException.getDescription() == origException.getDescription()) { return kj::READY_NOW; } - return kj::mv(e); + // We are shutting down after receive error, ignore shutdown exception since underlying + // transport is probably broken. + if (receiveIncomingMessageError) { + return kj::READY_NOW; + } + return kj::mv(shutdownException); }); disconnectFulfiller->fulfill(DisconnectInfo { kj::mv(shutdownPromise) }); - connection.init(kj::mv(networkException)); canceler.cancel(networkException); } @@ -486,6 +575,11 @@ private: bool skipFinish = false; // If true, don't send a Finish message. + // + // This is used in two cases: + // * The `Return` message had the `noFinishNeeded` hint. + // * Our attempt to send the `Call` threw an exception, therefore the peer never even received + // the call in the first place and would not expect a `Finish`. inline bool operator==(decltype(nullptr)) const { return !isAwaitingReturn && selfRef == nullptr; @@ -507,7 +601,17 @@ private: kj::Maybe> pipeline; // Send pipelined calls here. Becomes null as soon as a `Finish` is received. - kj::Maybe>> redirectedResults; + using Running = kj::Promise; + struct Finished {}; + using Redirected = kj::Promise>; + + kj::OneOf task; + // While the RPC is running locally, `task` is a `Promise` representing the task to execute + // the RPC. + // + // When `Finish` is received (and results are not redirected), `task` becomes `Finished`, which + // cancels it if it's still running. + // // For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call // result, to be picked up by a subsequent `Return`. @@ -609,6 +713,19 @@ private: kj::TaskSet tasks; + bool gotReturnForHighQuestionId = false; + // Becomes true if we ever get a `Return` message for a high question ID (with top bit set), + // which we use in cases where we've hinted to the peer that we don't want a `Return`. If the + // peer sends us one anyway then it seemingly doesn't not implement our hints. We need to stop + // using the hints in this case before the high question ID space wraps around since otherwise + // we might reuse an ID that the peer thinks is still in use. + + bool sentCapabilitiesInPipelineOnlyCall = false; + // Becomes true if `sendPipelineOnly()` is ever called with parameters that include capabilities. + + bool receiveIncomingMessageError = false; + // Becomes true when receiveIncomingMessage resulted in exception. + // ===================================================================================== // ClientHook implementations @@ -617,6 +734,13 @@ private: RpcClient(RpcConnectionState& connectionState) : connectionState(kj::addRef(connectionState)) {} + ~RpcClient() noexcept(false) { + KJ_IF_MAYBE(f, this->flowController) { + // Destroying the client should not cancel outstanding streaming calls. + connectionState->tasks.add(f->get()->waitAllAcked().attach(kj::mv(*f))); + } + } + virtual kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor, kj::Vector& fds) = 0; // Writes a CapDescriptor referencing this client. The CapDescriptor must be sent as part of @@ -663,12 +787,14 @@ private: // implements ClientHook ----------------------------------------- Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - return newCallNoIntercept(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + return newCallNoIntercept(interfaceId, methodId, sizeHint, hints); } Request newCallNoIntercept( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) { if (!connectionState->connection.is()) { return newBrokenRequest(kj::cp(connectionState->connection.get()), sizeHint); } @@ -680,29 +806,28 @@ private: callBuilder.setInterfaceId(interfaceId); callBuilder.setMethodId(methodId); + callBuilder.setNoPromisePipelining(hints.noPromisePipelining); + callBuilder.setOnlyPromisePipeline(hints.onlyPromisePipeline); auto root = request->getRoot(); return Request(root, kj::mv(request)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - return callNoIntercept(interfaceId, methodId, kj::mv(context)); + kj::Own&& context, CallHints hints) override { + return callNoIntercept(interfaceId, methodId, kj::mv(context), hints); } VoidPromiseAndPipeline callNoIntercept(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) { + kj::Own&& context, CallHints hints) { // Implement call() by copying params and results messages. auto params = context->getParams(); - auto request = newCallNoIntercept(interfaceId, methodId, params.targetSize()); + auto request = newCallNoIntercept(interfaceId, methodId, params.targetSize(), hints); request.set(params); context->releaseParams(); - // We can and should propagate cancellation. - context->allowCancellation(); - return context->directTailCall(RequestHook::from(kj::mv(request))); } @@ -923,19 +1048,20 @@ private: // implements ClientHook ----------------------------------------- Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { receivedCall = true; // IMPORTANT: We must call our superclass's version of newCall(), NOT cap->newCall(), because // the Request object we create needs to check at send() time whether the promise has // resolved and, if so, redirect to the new target. - return RpcClient::newCall(interfaceId, methodId, sizeHint); + return RpcClient::newCall(interfaceId, methodId, sizeHint, hints); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, CallHints hints) override { receivedCall = true; - return cap->call(interfaceId, methodId, kj::mv(context)); + return cap->call(interfaceId, methodId, kj::mv(context), hints); } kj::Maybe getResolved() override { @@ -1029,7 +1155,7 @@ private: if (other->isResolved()) { // The other capability resolved already. If it determined that it resolved as - // relfected, then we determine the same. + // reflected, then we determine the same. resolutionType = other->resolutionType; } else { // The other capability hasn't resolved yet, so we can safely merge with it and do a @@ -1379,7 +1505,7 @@ private: // // 2. The `Resolve` message contained a `CapDescriptor` of type `receiverHosted`, naming an // entry in the receiver's export table. That entry just happened to contain an - // `ImportClient` refering back to the sender. This specifically happens when the entry + // `ImportClient` referring back to the sender. This specifically happens when the entry // in question had previously itself referred to a promise, and that promise has since // resolved to a remote capability, at which point the export table entry was replaced by // the appropriate `ImportClient` representing that. Presumably, the peer *did not yet know* @@ -1416,12 +1542,13 @@ private: TribbleRaceBlocker(kj::Own inner): inner(kj::mv(inner)) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - return inner->newCall(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + return inner->newCall(interfaceId, methodId, sizeHint, hints); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - return inner->call(interfaceId, methodId, kj::mv(context)); + kj::Own&& context, CallHints hints) override { + return inner->call(interfaceId, methodId, kj::mv(context), hints); } kj::Maybe getResolved() override { // We always wrap either PipelineClient or ImportClient, both of which return null for this @@ -1526,7 +1653,7 @@ private: public: inline QuestionRef( RpcConnectionState& connectionState, QuestionId id, - kj::Own>>> fulfiller) + kj::Maybe>>>> fulfiller) : connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)) {} ~QuestionRef() noexcept { @@ -1534,57 +1661,75 @@ private: // throws (without being caught) we're probably in pretty bad shape and going to be crashing // later anyway. Better to abort now. - auto& question = KJ_ASSERT_NONNULL( - connectionState->questions.find(id), "Question ID no longer on table?"); + KJ_IF_MAYBE(c, connectionState) { + auto& connectionState = *c; - // Send the "Finish" message (if the connection is not already broken). - if (connectionState->connection.is() && !question.skipFinish) { - KJ_IF_MAYBE(e, kj::runCatchingExceptions([&]() { - auto message = connectionState->connection.get()->newOutgoingMessage( - messageSizeHint()); - auto builder = message->getBody().getAs().initFinish(); - builder.setQuestionId(id); - // If we're still awaiting a return, then this request is being canceled, and we're going - // to ignore any capabilities in the return message, so set releaseResultCaps true. If we - // already received the return, then we've already built local proxies for the caps and - // will send Release messages when those are destroyed. - builder.setReleaseResultCaps(question.isAwaitingReturn); - message->send(); - })) { - connectionState->disconnect(kj::mv(*e)); + auto& question = KJ_ASSERT_NONNULL( + connectionState->questions.find(id), "Question ID no longer on table?"); + + // Send the "Finish" message (if the connection is not already broken). + if (connectionState->connection.is() && !question.skipFinish) { + KJ_IF_MAYBE(e, kj::runCatchingExceptions([&]() { + auto message = connectionState->connection.get()->newOutgoingMessage( + messageSizeHint()); + auto builder = message->getBody().getAs().initFinish(); + builder.setQuestionId(id); + // If we're still awaiting a return, then this request is being canceled, and we're going + // to ignore any capabilities in the return message, so set releaseResultCaps true. If we + // already received the return, then we've already built local proxies for the caps and + // will send Release messages when those are destroyed. + builder.setReleaseResultCaps(question.isAwaitingReturn); + + // Let the peer know we don't have the early cancellation bug. + builder.setRequireEarlyCancellationWorkaround(false); + + message->send(); + })) { + connectionState->tasks.add(kj::mv(*e)); + } } - } - // Check if the question has returned and, if so, remove it from the table. - // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that - // the ID is not re-allocated before the `Finish` message can be sent. - if (question.isAwaitingReturn) { - // Still waiting for return, so just remove the QuestionRef pointer from the table. - question.selfRef = nullptr; - } else { - // Call has already returned, so we can now remove it from the table. - connectionState->questions.erase(id, question); + // Check if the question has returned and, if so, remove it from the table. + // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that + // the ID is not re-allocated before the `Finish` message can be sent. + if (question.isAwaitingReturn) { + // Still waiting for return, so just remove the QuestionRef pointer from the table. + question.selfRef = nullptr; + } else { + // Call has already returned, so we can now remove it from the table. + connectionState->questions.erase(id, question); + } } } inline QuestionId getId() const { return id; } void fulfill(kj::Own&& response) { - fulfiller->fulfill(kj::mv(response)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(kj::mv(response)); + } } void fulfill(kj::Promise>&& promise) { - fulfiller->fulfill(kj::mv(promise)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(kj::mv(promise)); + } } void reject(kj::Exception&& exception) { - fulfiller->reject(kj::mv(exception)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->reject(kj::mv(exception)); + } + } + + void disconnect() { + connectionState = nullptr; } private: - kj::Own connectionState; + kj::Maybe> connectionState; QuestionId id; - kj::Own>>> fulfiller; + kj::Maybe>>>> fulfiller; }; class RpcRequest final: public RequestHook { @@ -1609,6 +1754,7 @@ private: RemotePromise send() override { if (!connectionState->connection.is()) { // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? const kj::Exception& e = connectionState->connection.get(); return RemotePromise( kj::Promise>(kj::cp(e)), @@ -1620,19 +1766,29 @@ private: // We'll have to make a new request and do a copy. Ick. auto replacement = redirect->get()->newCall( - callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize()); + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); replacement.set(paramsBuilder); return replacement.send(); } else { + bool noPromisePipelining = callBuilder.getNoPromisePipelining(); + auto sendResult = sendInternal(false); - auto forkedPromise = sendResult.promise.fork(); + kj::Own pipeline; + if (noPromisePipelining) { + pipeline = getDisabledPipeline(); + } else { + auto forkedPromise = sendResult.promise.fork(); - // The pipeline must get notified of resolution before the app does to maintain ordering. - auto pipeline = kj::refcounted( - *connectionState, kj::mv(sendResult.questionRef), forkedPromise.addBranch()); + // The pipeline must get notified of resolution before the app does to maintain ordering. + pipeline = kj::refcounted( + *connectionState, kj::mv(sendResult.questionRef), forkedPromise.addBranch()); - auto appPromise = forkedPromise.addBranch().then( + sendResult.promise = forkedPromise.addBranch(); + } + + auto appPromise = sendResult.promise.then( [=](kj::Own&& response) { auto reader = response->getResults(); return Response(reader, kj::mv(response)); @@ -1647,6 +1803,7 @@ private: kj::Promise sendStreaming() override { if (!connectionState->connection.is()) { // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? return kj::cp(connectionState->connection.get()); } @@ -1655,7 +1812,8 @@ private: // We'll have to make a new request and do a copy. Ick. auto replacement = redirect->get()->newCall( - callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize()); + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); replacement.set(paramsBuilder); return RequestHook::from(kj::mv(replacement))->sendStreaming(); } else { @@ -1663,6 +1821,34 @@ private: } } + AnyPointer::Pipeline sendForPipeline() override { + if (!connectionState->connection.is()) { + // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? + const kj::Exception& e = connectionState->connection.get(); + return AnyPointer::Pipeline(newBrokenPipeline(kj::cp(e))); + } + + KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) { + // Whoops, this capability has been redirected while we were building the request! + // We'll have to make a new request and do a copy. Ick. + + auto replacement = redirect->get()->newCall( + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); + replacement.set(paramsBuilder); + return replacement.sendForPipeline(); + } else if (connectionState->gotReturnForHighQuestionId) { + // Peer doesn't implement our hints. Fall back to a regular send(). + return send(); + } else { + auto questionRef = sendForPipelineInternal(); + kj::Own pipeline = kj::refcounted( + *connectionState, kj::mv(questionRef)); + return AnyPointer::Pipeline(kj::mv(pipeline)); + } + } + struct TailInfo { QuestionId questionId; kj::Promise promise; @@ -1697,7 +1883,13 @@ private: QuestionId questionId = sendResult.questionRef->getId(); - auto pipeline = kj::refcounted(*connectionState, kj::mv(sendResult.questionRef)); + kj::Own pipeline; + bool noPromisePipelining = callBuilder.getNoPromisePipelining(); + if (noPromisePipelining) { + pipeline = getDisabledPipeline(); + } else { + pipeline = kj::refcounted(*connectionState, kj::mv(sendResult.questionRef)); + } return TailInfo { questionId, kj::mv(promise), kj::mv(pipeline) }; } @@ -1769,6 +1961,9 @@ private: })) { // We can't safely throw the exception from here since we've already modified the question // table state. We'll have to reject the promise instead. + // TODO(bug): Attempts to use the pipeline will end up sending a request referencing a + // bogus question ID. Can we rethrow after doing the appropriate cleanup, so the pipeline + // is never created? See the approach in sendForPipelineInternal() below. result.question.isAwaitingReturn = false; result.question.skipFinish = true; connectionState->releaseExports(result.question.paramExports); @@ -1810,6 +2005,47 @@ private: return kj::mv(flowPromise); } + + kj::Own sendForPipelineInternal() { + // Since must of setupSend() is subtly different for this case, we don't reuse it. + + // Build the cap table. + kj::Vector fds; + auto exports = connectionState->writeDescriptors( + capTable.getTable(), callBuilder.getParams(), fds); + message->setFds(fds.releaseAsArray()); + + if (exports.size() > 0) { + connectionState->sentCapabilitiesInPipelineOnlyCall = true; + } + + // Init the question table. Do this after writing descriptors to avoid interference. + QuestionId questionId; + auto& question = connectionState->questions.nextHigh(questionId); + question.isAwaitingReturn = false; // No Return needed + question.paramExports = kj::mv(exports); + question.isTailCall = false; + + // Make the QuentionRef and result promise. + auto questionRef = kj::refcounted(*connectionState, questionId, nullptr); + question.selfRef = *questionRef; + + // If sending throws, we'll need to fix up the state a little... + KJ_ON_SCOPE_FAILURE({ + question.skipFinish = true; + connectionState->releaseExports(question.paramExports); + }); + + // Finish and send. + callBuilder.setQuestionId(questionId); + callBuilder.setOnlyPromisePipeline(true); + + KJ_CONTEXT("sending RPC call", + callBuilder.getInterfaceId(), callBuilder.getMethodId()); + message->send(); + + return kj::mv(questionRef); + } }; class RpcPipeline final: public PipelineHook, public kj::Refcounted { @@ -1978,6 +2214,10 @@ private: return capTable.imbue(payload.getContent()); } + inline bool hasCapabilities() { + return capTable.getTable().size() > 0; + } + kj::Maybe> send() { // Send the response and return the export list. Returns nullptr if there were no caps. // (Could return a non-null empty array if there were caps but none of them were exports.) @@ -1988,14 +2228,16 @@ private: auto exports = connectionState.writeDescriptors(capTable, payload, fds); message->setFds(fds.releaseAsArray()); - // Capabilities that we are returning are subject to embargos. See `Disembargo` in rpc.capnp. - // As explained there, in order to deal with the Tribble 4-way race condition, we need to - // make sure that if we're returning any remote promises, that we ignore any subsequent - // resolution of those promises for the purpose of pipelined requests on this answer. Luckily, - // we can modify the cap table in-place. + // Populate `resolutionsAtReturnTime`. for (auto& slot: capTable) { KJ_IF_MAYBE(cap, slot) { - slot = connectionState.getInnermostClient(**cap); + auto inner = connectionState.getInnermostClient(**cap); + if (inner.get() != cap->get()) { + resolutionsAtReturnTime.upsert(cap->get(), kj::mv(inner), + [&](kj::Own& existing, kj::Own&& replacement) { + KJ_ASSERT(existing.get() == replacement.get()); + }); + } } } @@ -2007,11 +2249,40 @@ private: } } + struct Resolution { + kj::Own returnedCap; + // The capabiilty that appeared in the response message in this slot. + + kj::Own unwrapped; + // Exactly what `getInnermostClient(returnedCap)` produced at the time that the return + // message was encoded. + }; + + Resolution getResolutionAtReturnTime(kj::ArrayPtr ops) { + auto returnedCap = getResultsBuilder().asReader().getPipelinedCap(ops); + kj::Own unwrapped; + KJ_IF_MAYBE(u, resolutionsAtReturnTime.find(returnedCap.get())) { + unwrapped = u->get()->addRef(); + } else { + unwrapped = returnedCap->addRef(); + } + return { kj::mv(returnedCap), kj::mv(unwrapped) }; + } + private: RpcConnectionState& connectionState; kj::Own message; BuilderCapabilityTable capTable; rpc::Payload::Builder payload; + + kj::HashMap> resolutionsAtReturnTime; + // For each capability in `capTable` as of the time when the call returned, this map stores + // the result of calling `getInnermostClient()` on that capability. This is needed in order + // to solve the Tribble 4-way race condition described in the documentation for `Disembargo` + // in `rpc.capnp`. `PostReturnRpcPipeline`, below, uses this. + // + // As an optimization, if the innermost client is exactly the same object then nothing is + // stored in the map. }; class LocallyRedirectedRpcResponse final @@ -2037,16 +2308,87 @@ private: MallocMessageBuilder message; }; + class PostReturnRpcPipeline final: public PipelineHook, public kj::Refcounted { + // Once an incoming call has returned, we may need to replace the `PipelineHook` with one that + // correctly handles the Tribble 4-way race condition. Namely, we must ensure that if the + // response contained any capabilities pointing back out to the network, then any further + // pipelined calls received targetting those capabilities (as well as any Disembargo messages) + // will resolve to the same network capability forever, *even if* that network capability is + // itself a promise which later resolves to somewhere else. + public: + PostReturnRpcPipeline(kj::Own inner, + RpcServerResponseImpl& response, + kj::Own context) + : inner(kj::mv(inner)), response(response), context(kj::mv(context)) {} + + kj::Own addRef() override { + return kj::addRef(*this); + } + + kj::Own getPipelinedCap(kj::ArrayPtr ops) override { + auto resolved = response.getResolutionAtReturnTime(ops); + auto original = inner->getPipelinedCap(ops); + return getResolutionAtReturnTime(kj::mv(original), kj::mv(resolved)); + } + + kj::Own getPipelinedCap(kj::Array&& ops) override { + auto resolved = response.getResolutionAtReturnTime(ops); + auto original = inner->getPipelinedCap(kj::mv(ops)); + return getResolutionAtReturnTime(kj::mv(original), kj::mv(resolved)); + } + + private: + kj::Own inner; + RpcServerResponseImpl& response; + kj::Own context; // owns `response` + + kj::Own getResolutionAtReturnTime( + kj::Own original, RpcServerResponseImpl::Resolution resolution) { + // Wait for `original` to resolve to `resolution.returnedCap`, then return + // `resolution.unwrapped`. + + ClientHook* ptr = original.get(); + for (;;) { + if (ptr == resolution.returnedCap.get()) { + return kj::mv(resolution.unwrapped); + } else KJ_IF_MAYBE(r, ptr->getResolved()) { + ptr = r; + } else { + break; + } + } + + KJ_IF_MAYBE(p, ptr->whenMoreResolved()) { + return newLocalPromiseClient(p->then( + [this, original = kj::mv(original), resolution = kj::mv(resolution)] + (kj::Own r) mutable { + return getResolutionAtReturnTime(kj::mv(r), kj::mv(resolution)); + })); + } else if (ptr->isError() || ptr->isNull()) { + // This is already a broken capability, the error probably explains what went wrong. In + // any case, message ordering is irrelevant here since all calls will throw anyway. + return ptr->addRef(); + } else { + return newBrokenCap( + "An RPC call's capnp::PipelineHook object resolved a pipelined capability to a " + "different final object than what was returned in the actual response. This could " + "be a bug in Cap'n Proto, or could be due to a use of context.setPipeline() that " + "was inconsistent with the later results."); + } + } + }; + class RpcCallContext final: public CallContextHook, public kj::Refcounted { public: RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId, kj::Own&& request, kj::Array>> capTableArray, const AnyPointer::Reader& params, - bool redirectResults, kj::Own>&& cancelFulfiller, - uint64_t interfaceId, uint16_t methodId) + bool redirectResults, uint64_t interfaceId, uint16_t methodId, + ClientHook::CallHints hints) : connectionState(kj::addRef(connectionState)), answerId(answerId), + hints(hints), interfaceId(interfaceId), methodId(methodId), requestSize(request->sizeInWords()), @@ -2054,8 +2396,7 @@ private: paramsCapTable(kj::mv(capTableArray)), params(paramsCapTable.imbue(params)), returnMessage(nullptr), - redirectResults(redirectResults), - cancelFulfiller(kj::mv(cancelFulfiller)) { + redirectResults(redirectResults) { connectionState.callWordsInFlight += requestSize; } @@ -2063,9 +2404,10 @@ private: if (isFirstResponder()) { // We haven't sent a return yet, so we must have been canceled. Send a cancellation return. unwindDetector.catchExceptionsIfUnwinding([&]() { - // Don't send anything if the connection is broken. + // Don't send anything if the connection is broken, or if the onlyPromisePipeline hint + // was used (in which case the caller doesn't care to receive a `Return`). bool shouldFreePipeline = true; - if (connectionState->connection.is()) { + if (connectionState->connection.is() && !hints.onlyPromisePipeline) { auto message = connectionState->connection.get()->newOutgoingMessage( messageSizeHint() + sizeInWords()); auto builder = message->getBody().initAs().initReturn(); @@ -2104,10 +2446,11 @@ private: void sendReturn() { KJ_ASSERT(!redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); // Avoid sending results if canceled so that we don't have to figure out whether or not // `releaseResultCaps` was set in the already-received `Finish`. - if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) { + if (!receivedFinish && isFirstResponder()) { KJ_ASSERT(connectionState->connection.is(), "Cancellation should have been requested on disconnect.") { return; @@ -2118,17 +2461,43 @@ private: returnMessage.setAnswerId(answerId); returnMessage.setReleaseParamCaps(false); + auto& responseImpl = kj::downcast(*KJ_ASSERT_NONNULL(response)); + if (!responseImpl.hasCapabilities()) { + returnMessage.setNoFinishNeeded(true); + + // Tell ourselves that a finsih was already received, so that `cleanupAnswerTable()` + // removes the answer table entry. + receivedFinish = true; + + // HACK: The answer table's `task` is the thing which is calling `sendReturn()`. We can't + // cancel ourselves. However, we know calling `sendReturn()` is the last thing it does, + // so we can safely detach() it. + auto& answer = KJ_ASSERT_NONNULL(connectionState->answers.find(answerId)); + auto& selfPromise = KJ_ASSERT_NONNULL(answer.task.tryGet()); + selfPromise.detach([](kj::Exception&&) {}); + } + kj::Maybe> exports; KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { - // Debug info incase send() fails due to overside message. + // Debug info in case send() fails due to overside message. KJ_CONTEXT("returning from RPC call", interfaceId, methodId); - exports = kj::downcast(*KJ_ASSERT_NONNULL(response)).send(); + exports = responseImpl.send(); })) { responseSent = false; sendErrorReturn(kj::mv(*exception)); return; } + if (responseImpl.hasCapabilities()) { + auto& answer = KJ_ASSERT_NONNULL(connectionState->answers.find(answerId)); + // Swap out the `pipeline` in the answer table for one that will return capabilities + // consistent with whatever the result caps resolved to as of the time the return was sent. + answer.pipeline = answer.pipeline.map([&](kj::Own& inner) { + return kj::refcounted( + kj::mv(inner), responseImpl, kj::addRef(*this)); + }); + } + KJ_IF_MAYBE(e, exports) { // Caps were returned, so we can't free the pipeline yet. cleanupAnswerTable(kj::mv(*e), false); @@ -2140,6 +2509,7 @@ private: } void sendErrorReturn(kj::Exception&& exception) { KJ_ASSERT(!redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); if (isFirstResponder()) { if (connectionState->connection.is()) { auto message = connectionState->connection.get()->newOutgoingMessage( @@ -2150,6 +2520,12 @@ private: builder.setReleaseParamCaps(false); connectionState->fromException(exception, builder.initException()); + // Note that even though the response contains no capabilities, we don't want to set + // `noFinishNeeded` here because if any pipelined calls were made, we want them to + // fail with the correct exception. (Perhaps if the request had `noPromisePipelining`, + // then we could set `noFinishNeeded`, but optimizing the error case doesn't seem that + // important.) + message->send(); } @@ -2160,6 +2536,7 @@ private: } void sendRedirectReturn() { KJ_ASSERT(redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); if (isFirstResponder()) { auto message = connectionState->connection.get()->newOutgoingMessage( @@ -2170,27 +2547,20 @@ private: builder.setReleaseParamCaps(false); builder.setResultsSentElsewhere(); + // TODO(perf): Could `noFinishNeeded` be used here? The `Finish` messages are pretty + // redundant after a redirect, but as this case is less common and more complicated I + // don't want to fully think through the implications right now. + message->send(); cleanupAnswerTable(nullptr, false); } } - void requestCancel() { - // Hints that the caller wishes to cancel this call. At the next time when cancellation is - // deemed safe, the RpcCallContext shall send a canceled Return -- or if it never becomes - // safe, the RpcCallContext will send a normal return when the call completes. Either way - // the RpcCallContext is now responsible for cleaning up the entry in the answer table, since - // a Finish message was already received. + void finish() { + // Called when a `Finish` message is received while this object still exists. - bool previouslyAllowedButNotRequested = cancellationFlags == CANCEL_ALLOWED; - cancellationFlags |= CANCEL_REQUESTED; - - if (previouslyAllowedButNotRequested) { - // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate - // the cancellation. - cancelFulfiller->fulfill(); - } + receivedFinish = true; } // implements CallContextHook ------------------------------------ @@ -2240,9 +2610,13 @@ private: KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct."); - if (request->getBrand() == connectionState.get() && !redirectResults) { + if (request->getBrand() == connectionState.get() && + !redirectResults && !hints.noPromisePipelining) { // The tail call is headed towards the peer that called us in the first place, so we can // optimize out the return trip. + // + // If the noPromisePipelining hint was sent, we skip this trick since the caller will + // ignore the `Return` message anyway. KJ_IF_MAYBE(tailInfo, kj::downcast(*request).tailSend()) { if (isFirstResponder()) { @@ -2267,6 +2641,14 @@ private: } // Just forwarding to another local call. + + if (hints.onlyPromisePipeline) { + return { + kj::NEVER_DONE, + PipelineHook::from(request->sendForPipeline()) + }; + } + auto promise = request->send(); // Wait for response. @@ -2284,16 +2666,6 @@ private: tailCallPipelineFulfiller = kj::mv(paf.fulfiller); return kj::mv(paf.promise); } - void allowCancellation() override { - bool previouslyRequestedButNotAllowed = cancellationFlags == CANCEL_REQUESTED; - cancellationFlags |= CANCEL_ALLOWED; - - if (previouslyRequestedButNotAllowed) { - // We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Initiate - // the cancellation. - cancelFulfiller->fulfill(); - } - } kj::Own addRef() override { return kj::addRef(*this); } @@ -2302,6 +2674,8 @@ private: kj::Own connectionState; AnswerId answerId; + ClientHook::CallHints hints; + uint64_t interfaceId; uint16_t methodId; // For debugging. @@ -2323,18 +2697,9 @@ private: // Cancellation state ---------------------------------- - enum CancellationFlags { - CANCEL_REQUESTED = 1, - CANCEL_ALLOWED = 2 - }; - - uint8_t cancellationFlags = 0; - // When both flags are set, the cancellation process will begin. - - kj::Own> cancelFulfiller; - // Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is - // exclusive-joined with the outermost promise waiting on the call return, so fulfilling it - // cancels that promise. + bool receivedFinish = false; + // True if a `Finish` message has been recevied OR we sent a `Return` with `noFinishNedeed`. + // In either case, it is our responsibility to clean up the answer table. kj::UnwindDetector unwindDetector; @@ -2354,7 +2719,7 @@ private: // answer table. Or we might even be responsible for removing the entire answer table // entry. - if (cancellationFlags & CANCEL_REQUESTED) { + if (receivedFinish) { // Already received `Finish` so it's our job to erase the table entry. We shouldn't have // sent results if canceled, so we shouldn't have an export list to deal with. KJ_ASSERT(resultExports.size() == 0); @@ -2410,9 +2775,13 @@ private: handleMessage(kj::mv(*m)); return true; } else { - disconnect(KJ_EXCEPTION(DISCONNECTED, "Peer disconnected.")); + tasks.add(KJ_EXCEPTION(DISCONNECTED, "Peer disconnected.")); return false; } + }, [this](kj::Exception&& exception) { + receiveIncomingMessageError = true; + kj::throwRecoverableException(kj::mv(exception)); + return false; }).then([this](bool keepGoing) { // No exceptions; continue loop. // @@ -2421,13 +2790,13 @@ private: // // TODO(perf): We add an evalLater() here so that anything we needed to do in reaction to // the previous message has a chance to complete before the next message is handled. In - // paticular, without this, I observed an ordering problem: I saw a case where a `Return` + // particular, without this, I observed an ordering problem: I saw a case where a `Return` // message was followed by a `Resolve` message, but the `PromiseClient` associated with the // `Resolve` had its `resolve()` method invoked _before_ any `PromiseClient`s associated // with pipelined capabilities resolved by the `Return`. This could lead to an // incorrectly-ordered interaction between `PromiseClient`s when they resolve to each // other. This is probably really a bug in the way `Return`s are handled -- apparently, - // resolution of `PromiseClient`s based on returned capabilites does not occur in a + // resolution of `PromiseClient`s based on returned capabilities does not occur in a // depth-first way, when it should. If we could fix that then we can probably remove this // `evalLater()`. However, the `evalLater()` is not that bad and solves the problem... if (keepGoing) tasks.add(kj::evalLater([this]() { return messageLoop(); })); @@ -2598,7 +2967,14 @@ private: kj::Vector fds; resultExports = writeDescriptors(capTableArray, payload, fds); response->setFds(fds.releaseAsArray()); - capHook = KJ_ASSERT_NONNULL(capTableArray[0])->addRef(); + + // If we're returning a capability that turns out to be an PromiseClient pointing back on + // this same network, it's important we remove the `PromiseClient` layer and use the inner + // capability instead. This achieves the same effect that `PostReturnRpcPipeline` does for + // regular call returns. + // + // This single line of code represents two hours of my life. + capHook = getInnermostClient(*KJ_ASSERT_NONNULL(capTableArray[0])); })) { fromException(*exception, ret.initException()); capHook = newBrokenCap(kj::mv(*exception)); @@ -2643,14 +3019,18 @@ private: auto payload = call.getParams(); auto capTableArray = receiveCaps(payload.getCapTable(), message->getAttachedFds()); - auto cancelPaf = kj::newPromiseAndFulfiller(); AnswerId answerId = call.getQuestionId(); + auto hints = callHintsFromReader(call); + + // Don't honor onlyPromisePipeline if results are redirected, because this situation isn't + // useful in practice and would be complicated to handle "correctly". + if (redirectResults) hints.onlyPromisePipeline = false; + auto context = kj::refcounted( *this, answerId, kj::mv(message), kj::mv(capTableArray), payload.getContent(), - redirectResults, kj::mv(cancelPaf.fulfiller), - call.getInterfaceId(), call.getMethodId()); + redirectResults, call.getInterfaceId(), call.getMethodId(), hints); // No more using `call` after this point, as it now belongs to the context. @@ -2666,7 +3046,7 @@ private: } auto promiseAndPipeline = startCall( - call.getInterfaceId(), call.getMethodId(), kj::mv(capability), context->addRef()); + call.getInterfaceId(), call.getMethodId(), kj::mv(capability), context->addRef(), hints); // Things may have changed -- in particular if startCall() immediately called // context->directTailCall(). @@ -2677,45 +3057,38 @@ private: answer.pipeline = kj::mv(promiseAndPipeline.pipeline); if (redirectResults) { - auto resultsPromise = promiseAndPipeline.promise.then( - kj::mvCapture(context, [](kj::Own&& context) { + answer.task = promiseAndPipeline.promise.then( + [context=kj::mv(context)]() mutable { return context->consumeRedirectedResponse(); - })); - - // If the call that later picks up `redirectedResults` decides to discard it, we need to - // make sure our call is not itself canceled unless it has called allowCancellation(). - // So we fork the promise and join one branch with the cancellation promise, in order to - // hold on to it. - auto forked = resultsPromise.fork(); - answer.redirectedResults = forked.addBranch(); - - cancelPaf.promise - .exclusiveJoin(forked.addBranch().then([](kj::Own&&){})) - .detach([](kj::Exception&&) {}); + }); + } else if (hints.onlyPromisePipeline) { + // The promise is probably fake anyway, so don't bother adding a .then(). We do, however, + // have to attach `context` to this, since we destroy `task` upon receiving a `Finish` + // message, and we want `RpcCallContext` to be destroyed no earlier than that. + answer.task = promiseAndPipeline.promise.attach(kj::mv(context)); } else { // Hack: Both the success and error continuations need to use the context. We could // refcount, but both will be destroyed at the same time anyway. - RpcCallContext* contextPtr = context; - - promiseAndPipeline.promise.then( - [contextPtr]() { - contextPtr->sendReturn(); - }, [contextPtr](kj::Exception&& exception) { - contextPtr->sendErrorReturn(kj::mv(exception)); - }).catch_([&](kj::Exception&& exception) { + RpcCallContext& contextRef = *context; + + answer.task = promiseAndPipeline.promise.then( + [context = kj::mv(context)]() mutable { + context->sendReturn(); + }, [&contextRef](kj::Exception&& exception) { + contextRef.sendErrorReturn(kj::mv(exception)); + }).eagerlyEvaluate([&](kj::Exception&& exception) { // Handle exceptions that occur in sendReturn()/sendErrorReturn(). taskFailed(kj::mv(exception)); - }).attach(kj::mv(context)) - .exclusiveJoin(kj::mv(cancelPaf.promise)) - .detach([](kj::Exception&&) {}); + }); } } } ClientHook::VoidPromiseAndPipeline startCall( uint64_t interfaceId, uint64_t methodId, - kj::Own&& capability, kj::Own&& context) { - return capability->call(interfaceId, methodId, kj::mv(context)); + kj::Own&& capability, kj::Own&& context, + ClientHook::CallHints hints) { + return capability->call(interfaceId, methodId, kj::mv(context), hints); } kj::Maybe> getMessageTarget(const rpc::MessageTarget::Reader& target) { @@ -2735,13 +3108,14 @@ private: auto promisedAnswer = target.getPromisedAnswer(); kj::Own pipeline; - auto& base = answers[promisedAnswer.getQuestionId()]; - KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") { - return nullptr; + KJ_IF_MAYBE(answer, answers.find(promisedAnswer.getQuestionId())) { + if (answer->active) { + KJ_IF_MAYBE(p, answer->pipeline) { + pipeline = p->get()->addRef(); + } + } } - KJ_IF_MAYBE(p, base.pipeline) { - pipeline = p->get()->addRef(); - } else { + if (pipeline.get() == nullptr) { pipeline = newBrokenPipeline(KJ_EXCEPTION(FAILED, "Pipeline call on a request that returned no capabilities or was already closed.")); } @@ -2768,9 +3142,41 @@ private: // pointer into it, so make sure these destructors run later. kj::Array exportsToRelease; KJ_DEFER(releaseExports(exportsToRelease)); - kj::Maybe>> promiseToRelease; + kj::Maybe promiseToRelease; + + QuestionId questionId = ret.getAnswerId(); + if (questions.isHigh(questionId)) { + // We sent hints with this question saying we didn't want a `Return` but we got one anyway. + // We cannot even look up the question on the question table because it's (remotely) possible + // that we already removed it and re-allocated the ID to something else. So, we should ignore + // the `Return`. But we might want to make note to stop using these hints, to protect against + // the (again, remote) possibility of our ID space wrapping around and leading to confusion. + if (ret.getReleaseParamCaps() && sentCapabilitiesInPipelineOnlyCall) { + // Oh no, it appears the peer wants us to release any capabilities in the params, something + // which only a level 0 peer would request (no version of the C++ RPC system has ever done + // this). And it appears we did send capabilities in at least one pipeline-only call + // previously. But we have no record of which capabilities were sent in *this* call, so + // we cannot release them. Log an error about the leak. + // + // This scenario is unlikely to happen in practice, because sendForPipeline() is not useful + // when talking to a peer that doesn't support capability-passing -- they couldn't possibly + // return a capability to pipeline on! So, I'm not going to spend time to find a solution + // for this corner case. We will log an error, though, just in case someone hits this + // somehow. + KJ_LOG(ERROR, + "sendForPipeline() was used when sending an RPC to a peer, the parameters of that " + "RPC included capabilities, but the peer seems to implement Cap'n Proto at level 0, " + "meaning it does not support capability passing (or, at least, it sent a `Return` " + "with `releaseParamCaps = true`). The capabilities that were sent may have been " + "leaked (they won't be dropped until the connection closes)."); + + sentCapabilitiesInPipelineOnlyCall = false; // don't log again + } + gotReturnForHighQuestionId = true; + return; + } - KJ_IF_MAYBE(question, questions.find(ret.getAnswerId())) { + KJ_IF_MAYBE(question, questions.find(questionId)) { KJ_REQUIRE(question->isAwaitingReturn, "Duplicate Return.") { return; } question->isAwaitingReturn = false; @@ -2780,6 +3186,10 @@ private: question->paramExports = nullptr; } + if (ret.getNoFinishNeeded()) { + question->skipFinish = true; + } + KJ_IF_MAYBE(questionRef, question->selfRef) { switch (ret.which()) { case rpc::Return::RESULTS: { @@ -2821,24 +3231,14 @@ private: case rpc::Return::TAKE_FROM_OTHER_QUESTION: KJ_IF_MAYBE(answer, answers.find(ret.getTakeFromOtherQuestion())) { - KJ_IF_MAYBE(response, answer->redirectedResults) { + KJ_IF_MAYBE(response, answer->task.tryGet()) { questionRef->fulfill(kj::mv(*response)); - answer->redirectedResults = nullptr; + answer->task = Answer::Finished(); KJ_IF_MAYBE(context, answer->callContext) { // Send the `Return` message for the call of which we're taking ownership, so // that the peer knows it can now tear down the call state. context->sendRedirectReturn(); - - // There are three conditions, all of which must be true, before a call is - // canceled: - // 1. The RPC opts in by calling context->allowCancellation(). - // 2. We request cancellation with context->requestCancel(). - // 3. The final response promise -- which we passed to questionRef->fulfill() - // above -- must be dropped. - // - // We would like #3 to imply #2. So... we can just make #2 be true. - context->requestCancel(); } } else { KJ_FAIL_REQUIRE("`Return.takeFromOtherQuestion` referenced a call that did not " @@ -2864,16 +3264,12 @@ private: // Indeed, it does still exist. // Throw away the result promise. - promiseToRelease = kj::mv(answer->redirectedResults); + promiseToRelease = kj::mv(answer->task); KJ_IF_MAYBE(context, answer->callContext) { // Send the `Return` message for the call of which we're taking ownership, so // that the peer knows it can now tear down the call state. context->sendRedirectReturn(); - - // Since the caller has been canceled, make sure the callee that we're tailing to - // gets canceled. - context->requestCancel(); } } } @@ -2896,9 +3292,13 @@ private: KJ_DEFER(releaseExports(exportsToRelease)); Answer answerToRelease; kj::Maybe> pipelineToRelease; + kj::Maybe promiseToRelease; KJ_IF_MAYBE(answer, answers.find(finish.getQuestionId())) { - KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; } + if (!answer->active) { + // Treat the same as if the answer wasn't in the table; see comment below. + return; + } if (finish.getReleaseResultCaps()) { exportsToRelease = kj::mv(answer->resultExports); @@ -2908,15 +3308,54 @@ private: pipelineToRelease = kj::mv(answer->pipeline); - // If the call isn't actually done yet, cancel it. Otherwise, we can go ahead and erase the - // question from the table. KJ_IF_MAYBE(context, answer->callContext) { - context->requestCancel(); + // Destroying answer->task will probably destroy the call context, but we can't prove that + // since it's refcounted. Instead, inform the call context that it is now its job to + // clean up the answer table. Then, cancel the task. + promiseToRelease = kj::mv(answer->task); + answer->task = Answer::Finished(); + context->finish(); } else { + // The call context is already gone so we can tear down the Answer here. answerToRelease = answers.erase(finish.getQuestionId()); } } else { - KJ_FAIL_REQUIRE("'Finish' for invalid question ID.") { return; } + // The `Finish` message targets a qusetion ID that isn't present in our answer table. + // Probably, we send a `Return` with `noFinishNeeded = true`, but the other side didn't + // recognize this hint and sent a `Finish` anyway, or the `Finish` was already in-flight at + // the time we sent the `Return`. We can silently ignore this. + // + // It would be nice to detect invalid finishes somehow, but to do so we would have to + // remember past answer IDs somewhere even when we said `noFinishNeeded`. Assuming the other + // side respects the hint and doesn't send a `Finish`, we'd only be able to clean up these + // records when the other end reuses the question ID, which might never happen. + } + + if (finish.getRequireEarlyCancellationWorkaround()) { + // Defer actual cancellation of the call until the end of the event loop queue. + // + // This is needed for compatibility with older versions of Cap'n Proto (0.10 and prior) in + // which the default was to prohibit cancellation until it was explicitly allowed. In newer + // versions (1.0 and later) cancellation is allowed until explicitly prohibited, that is, if + // we haven't actually delivered the call yet, it can be canceled. This requires less + // bookkeeping and so improved performance. + // + // However, old clients might be inadvertently relying on the old behavior. For example, if + // someone using and old version called `.send()` on a message and then promptly dropped the + // returned Promise, the message would often be delivered. This was not intended to work, but + // did, and could be relied upon by accident. Moreover, the original implementation of + // streaming included a bug where streaming calls *always* sent an immediate Finish. + // + // By deferring cancellation until after a turn of the event loop, we provide an opportunity + // for any `Call` messages we've received to actually be delivered, so that they can opt out + // of cancellation if desired. + KJ_IF_MAYBE(task, promiseToRelease) { + KJ_IF_MAYBE(running, task->tryGet()) { + tasks.add(kj::evalLast([running = kj::mv(*running)]() { + // Just drop `running` here to cancel the call. + })); + } + } } } @@ -3006,26 +3445,40 @@ private: return; } - for (;;) { - KJ_IF_MAYBE(r, target->getResolved()) { - target = r->addRef(); - } else { - break; - } - } + EmbargoId embargoId = context.getSenderLoopback(); - KJ_REQUIRE(target->getBrand() == this, - "'Disembargo' of type 'senderLoopback' sent to an object that does not point " - "back to the sender.") { - return; - } + // It's possible that `target` is a promise capability that hasn't resolved yet, in which + // case we must wait for the resolution. In particular this can happen in the case where + // we have Alice -> Bob -> Carol, Alice makes a call that proxies from Bob to Carol, and + // Carol returns a capability from this call that points all the way back though Bob to + // Alice. When this return capability passes through Bob, Bob will resolve the previous + // promise-pipeline capability to it. However, Bob has to send a Disembargo to Carol before + // completing this resolution. In the meantime, though, Bob returns the final repsonse to + // Alice. Alice then *also* sends a Disembargo to Bob. The Alice -> Bob Disembargo might + // arrive at Bob before the Bob -> Carol Disembargo has resolved, in which case the + // Disembargo is delivered to a promise capability. + auto promise = target->whenResolved() + .then([]() { + // We also need to insert an evalLast() here to make sure that any pending calls towards + // this cap have had time to find their way through the event loop. + return kj::evalLast([]() {}); + }); - EmbargoId embargoId = context.getSenderLoopback(); + tasks.add(promise.then([this, embargoId, target = kj::mv(target)]() mutable { + for (;;) { + KJ_IF_MAYBE(r, target->getResolved()) { + target = r->addRef(); + } else { + break; + } + } + + KJ_REQUIRE(target->getBrand() == this, + "'Disembargo' of type 'senderLoopback' sent to an object that does not point " + "back to the sender.") { + return; + } - // We need to insert an evalLast() here to make sure that any pending calls towards this - // cap have had time to find their way through the event loop. - tasks.add(canceler.wrap(kj::evalLast(kj::mvCapture( - target, [this,embargoId](kj::Own&& target) { if (!connection.is()) { return; } @@ -3045,8 +3498,8 @@ private: // any promise with a direct node in order to solve the Tribble 4-way race condition. // See the documentation of Disembargo in rpc.capnp for more. KJ_REQUIRE(redirect == nullptr, - "'Disembargo' of type 'senderLoopback' sent to an object that does not " - "appear to have been the subject of a previous 'Resolve' message.") { + "'Disembargo' of type 'senderLoopback' sent to an object that does not " + "appear to have been the subject of a previous 'Resolve' message.") { return; } } @@ -3054,7 +3507,7 @@ private: builder.getContext().setReceiverLoopback(embargoId); message->send(); - })))); + })); break; } @@ -3104,7 +3557,7 @@ public: // disassemble it. if (!connections.empty()) { kj::Vector> deleteMe(connections.size()); - kj::Exception shutdownException = KJ_EXCEPTION(FAILED, "RpcSystem was destroyed."); + kj::Exception shutdownException = KJ_EXCEPTION(DISCONNECTED, "RpcSystem was destroyed."); for (auto& entry: connections) { entry.second->disconnect(kj::cp(shutdownException)); deleteMe.add(kj::mv(entry.second)); @@ -3388,4 +3841,20 @@ kj::Own RpcFlowController::newVariableWindowController(Window return kj::heap(getter); } +bool IncomingRpcMessage::isShortLivedRpcMessage(AnyPointer::Reader body) { + switch (body.getAs().which()) { + case rpc::Message::CALL: + case rpc::Message::RETURN: + return false; + default: + return true; + } +} + +kj::Function IncomingRpcMessage::getShortLivedCallback() { + return [](MessageReader& reader) { + return IncomingRpcMessage::isShortLivedRpcMessage(reader.getRoot()); + }; +} + } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp index 50aa496369b..0e718d5bf22 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp @@ -316,7 +316,7 @@ struct Bootstrap { # A Vat may export multiple bootstrap interfaces. In this case, `deprecatedObjectId` specifies # which one to return. If this pointer is null, then the default bootstrap interface is returned. # - # As of verison 0.5, use of this field is deprecated. If a service wants to export multiple + # As of version 0.5, use of this field is deprecated. If a service wants to export multiple # bootstrap interfaces, it should instead define a single bootstrap interface that has methods # that return each of the other interfaces. # @@ -352,7 +352,7 @@ struct Bootstrap { # - Overloading "Restore" also had a security problem: Often, "main" or "well-known" # capabilities exported by a vat are in fact not public: they are intended to be accessed only # by clients who are capable of forming a connection to the vat. This can lead to trouble if - # the client itself has other clients and wishes to foward some `Restore` requests from those + # the client itself has other clients and wishes to forward some `Restore` requests from those # external clients -- it has to be very careful not to allow through `Restore` requests # addressing the default capability. # @@ -415,6 +415,30 @@ struct Call { # `acceptFromThirdParty`. Level 3 implementations should set this true. Otherwise, the callee # will have to proxy the return in the case of a tail call to a third-party vat. + noPromisePipelining @9 :Bool = false; + # If true, the sender promises that it won't make any promise-pipelined calls on the results of + # this call. If it breaks this promise, the receiver may throw an arbitrary error from such + # calls. + # + # The receiver may use this as an optimization, by skipping the bookkeeping needed for pipelining + # when no pipelined calls are expected. The sender typically sets this to false when the method's + # schema does not specify any return capabilities. + + onlyPromisePipeline @10 :Bool = false; + # If true, the sender only plans to use this call to make pipelined calls. The receiver need not + # send a `Return` message (but is still allowed to do so). + # + # Since the sender does not know whether a `Return` will be sent, it must release all state + # related to the call when it sends `Finish`. However, in the case that the callee does not + # recognize this hint and chooses to send a `Return`, then technically the caller is not allowed + # to reuse the question ID until it receives said `Return`. This creates a conundrum: How does + # the caller decide when it's OK to reuse the ID? To sidestep the problem, the C++ implementation + # uses high-numbered IDs (with the high-order bit set) for such calls, and cycles through the + # IDs in order. If all 2^31 IDs in this space are used without ever seeing a `Return`, then the + # implementation assumes that the other end is in fact honoring the hint, and the ID counter is + # allowed to loop around. If a `Return` is ever seen when `onlyPromisePipeline` was set, then + # the implementation stops using this hint. + params @4 :Payload; # The call parameters. `params.content` is a struct whose fields correspond to the parameters of # the method. @@ -496,6 +520,13 @@ struct Return { # The receiver should act as if the sender had sent a release message with count=1 for each # CapDescriptor in the original Call message. + noFinishNeeded @8 :Bool = false; + # If true, the sender does not need the receiver to send a `Finish` message; its answer table + # entry has already been cleaned up. This implies that the results do not contain any + # capabilities, since the `Finish` message would normally release those capabilities from + # promise pipelining responsibility. The caller may still send a `Finish` message if it wants, + # which will be silently ignored by the callee. + union { results @2 :Payload; # The result. @@ -564,6 +595,20 @@ struct Finish { # should always set this true. This defaults true because if level 0 implementations forget to # set it they'll never notice (just silently leak caps), but if level >=1 implementations forget # set it false they'll quickly get errors. + + requireEarlyCancellationWorkaround @2 :Bool = true; + # If true, if the RPC system receives this Finish message before the original call has even been + # delivered, it should defer cancellation util after delivery. In particular, this gives the + # destination object a chance to opt out of cancellation, e.g. as controlled by the + # `allowCancellation` annotation defined in `c++.capnp`. + # + # This is a work-around. Versions 1.0 and up of Cap'n Proto always set this to false. However, + # older versions of Cap'n Proto unintentionally exhibited this errant behavior by default, and + # as a result programs built with older versions could be inadvertently relying on their peers + # to implement the behavior. The purpose of this flag is to let newer versions know when the + # peer is an older version, so that it can attempt to work around the issue. + # + # See also comments in handleFinish() in rpc.c++ for more details. } # Level 1 message types ---------------------------------------------- @@ -707,6 +752,10 @@ struct Disembargo { # is expected that people sending messages to P will shortly start sending them to R instead and # drop P. P is at end-of-life anyway, so it doesn't matter if it ignores chances to further # optimize its path. + # + # Note well: the Tribble 4-way race condition does not require each vat to be *distinct*; as long + # as each resolution crosses a network boundary the race can occur -- so this concerns even level + # 1 implementations, not just level 3 implementations. target @0 :MessageTarget; # What is to be disembargoed. diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.c++ index 553e72d6991..4927995f474 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.c++ @@ -259,7 +259,7 @@ static const uint16_t m_91b79f1f808db032[] = {1, 11, 8, 2, 13, 4, 12, 9, 7, 10, static const uint16_t i_91b79f1f808db032[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}; const ::capnp::_::RawSchema s_91b79f1f808db032 = { 0x91b79f1f808db032, b_91b79f1f808db032.words, 232, d_91b79f1f808db032, m_91b79f1f808db032, - 12, 14, i_91b79f1f808db032, nullptr, nullptr, { &s_91b79f1f808db032, nullptr, nullptr, 0, 0, nullptr } + 12, 14, i_91b79f1f808db032, nullptr, nullptr, { &s_91b79f1f808db032, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_e94ccf8031176ec4 = { @@ -321,10 +321,10 @@ static const uint16_t m_e94ccf8031176ec4[] = {1, 0}; static const uint16_t i_e94ccf8031176ec4[] = {0, 1}; const ::capnp::_::RawSchema s_e94ccf8031176ec4 = { 0xe94ccf8031176ec4, b_e94ccf8031176ec4.words, 51, nullptr, m_e94ccf8031176ec4, - 0, 2, i_e94ccf8031176ec4, nullptr, nullptr, { &s_e94ccf8031176ec4, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_e94ccf8031176ec4, nullptr, nullptr, { &s_e94ccf8031176ec4, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { +static const ::capnp::_::AlignedData<155> b_836a53ce789d4cd4 = { { 0, 0, 0, 0, 5, 0, 6, 0, 212, 76, 157, 120, 206, 83, 106, 131, 16, 0, 0, 0, 1, 0, 3, 0, @@ -334,63 +334,77 @@ static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { 21, 0, 0, 0, 170, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 143, 1, 0, 0, + 25, 0, 0, 0, 255, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 67, 97, 108, 108, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 28, 0, 0, 0, 3, 0, 4, 0, + 36, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 181, 0, 0, 0, 90, 0, 0, 0, + 237, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 180, 0, 0, 0, 3, 0, 1, 0, - 192, 0, 0, 0, 2, 0, 1, 0, + 236, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 189, 0, 0, 0, 58, 0, 0, 0, + 245, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 184, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, + 240, 0, 0, 0, 3, 0, 1, 0, + 252, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 98, 0, 0, 0, + 249, 0, 0, 0, 98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 192, 0, 0, 0, 3, 0, 1, 0, - 204, 0, 0, 0, 2, 0, 1, 0, + 248, 0, 0, 0, 3, 0, 1, 0, + 4, 1, 0, 0, 2, 0, 1, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 201, 0, 0, 0, 74, 0, 0, 0, + 1, 1, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 200, 0, 0, 0, 3, 0, 1, 0, - 212, 0, 0, 0, 2, 0, 1, 0, - 5, 0, 0, 0, 1, 0, 0, 0, + 0, 1, 0, 0, 3, 0, 1, 0, + 12, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 209, 0, 0, 0, 58, 0, 0, 0, + 9, 1, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 204, 0, 0, 0, 3, 0, 1, 0, - 216, 0, 0, 0, 2, 0, 1, 0, - 6, 0, 0, 0, 0, 0, 0, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 153, 95, 171, 26, 246, 176, 232, 218, - 213, 0, 0, 0, 114, 0, 0, 0, + 13, 1, 0, 0, 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 128, 0, 0, 0, 0, 0, 1, 0, 8, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 194, 0, 0, 0, + 249, 0, 0, 0, 194, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 196, 0, 0, 0, 3, 0, 1, 0, - 208, 0, 0, 0, 2, 0, 1, 0, + 252, 0, 0, 0, 3, 0, 1, 0, + 8, 1, 0, 0, 2, 0, 1, 0, + 5, 0, 0, 0, 129, 0, 0, 0, + 0, 0, 1, 0, 9, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 5, 1, 0, 0, 162, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 8, 1, 0, 0, 3, 0, 1, 0, + 20, 1, 0, 0, 2, 0, 1, 0, + 6, 0, 0, 0, 130, 0, 0, 0, + 0, 0, 1, 0, 10, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 17, 1, 0, 0, 162, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 20, 1, 0, 0, 3, 0, 1, 0, + 32, 1, 0, 0, 2, 0, 1, 0, 113, 117, 101, 115, 116, 105, 111, 110, 73, 100, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -439,6 +453,26 @@ static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { 97, 108, 108, 111, 119, 84, 104, 105, 114, 100, 80, 97, 114, 116, 121, 84, 97, 105, 108, 67, 97, 108, 108, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 110, 111, 80, 114, 111, 109, 105, 115, + 101, 80, 105, 112, 101, 108, 105, 110, + 105, 110, 103, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 111, 110, 108, 121, 80, 114, 111, 109, + 105, 115, 101, 80, 105, 112, 101, 108, + 105, 110, 101, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -454,11 +488,11 @@ static const ::capnp::_::RawSchema* const d_836a53ce789d4cd4[] = { &s_9a0e61223d96743b, &s_dae8b0f61aab5f99, }; -static const uint16_t m_836a53ce789d4cd4[] = {6, 2, 3, 4, 0, 5, 1}; -static const uint16_t i_836a53ce789d4cd4[] = {0, 1, 2, 3, 4, 5, 6}; +static const uint16_t m_836a53ce789d4cd4[] = {6, 2, 3, 7, 8, 4, 0, 5, 1}; +static const uint16_t i_836a53ce789d4cd4[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; const ::capnp::_::RawSchema s_836a53ce789d4cd4 = { - 0x836a53ce789d4cd4, b_836a53ce789d4cd4.words, 121, d_836a53ce789d4cd4, m_836a53ce789d4cd4, - 3, 7, i_836a53ce789d4cd4, nullptr, nullptr, { &s_836a53ce789d4cd4, nullptr, nullptr, 0, 0, nullptr } + 0x836a53ce789d4cd4, b_836a53ce789d4cd4.words, 155, d_836a53ce789d4cd4, m_836a53ce789d4cd4, + 3, 9, i_836a53ce789d4cd4, nullptr, nullptr, { &s_836a53ce789d4cd4, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_dae8b0f61aab5f99 = { @@ -537,10 +571,10 @@ static const uint16_t m_dae8b0f61aab5f99[] = {0, 2, 1}; static const uint16_t i_dae8b0f61aab5f99[] = {0, 1, 2}; const ::capnp::_::RawSchema s_dae8b0f61aab5f99 = { 0xdae8b0f61aab5f99, b_dae8b0f61aab5f99.words, 65, d_dae8b0f61aab5f99, m_dae8b0f61aab5f99, - 1, 3, i_dae8b0f61aab5f99, nullptr, nullptr, { &s_dae8b0f61aab5f99, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_dae8b0f61aab5f99, nullptr, nullptr, { &s_dae8b0f61aab5f99, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { +static const ::capnp::_::AlignedData<164> b_9e19b28d3db3573a = { { 0, 0, 0, 0, 5, 0, 6, 0, 58, 87, 179, 61, 141, 178, 25, 158, 16, 0, 0, 0, 1, 0, 2, 0, @@ -550,70 +584,77 @@ static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { 21, 0, 0, 0, 186, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 199, 1, 0, 0, + 25, 0, 0, 0, 255, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 82, 101, 116, 117, 114, 110, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 32, 0, 0, 0, 3, 0, 4, 0, + 36, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 209, 0, 0, 0, 74, 0, 0, 0, + 237, 0, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 208, 0, 0, 0, 3, 0, 1, 0, - 220, 0, 0, 0, 2, 0, 1, 0, + 236, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 32, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 217, 0, 0, 0, 138, 0, 0, 0, + 245, 0, 0, 0, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 220, 0, 0, 0, 3, 0, 1, 0, - 232, 0, 0, 0, 2, 0, 1, 0, - 2, 0, 255, 255, 0, 0, 0, 0, + 248, 0, 0, 0, 3, 0, 1, 0, + 4, 1, 0, 0, 2, 0, 1, 0, + 3, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 229, 0, 0, 0, 66, 0, 0, 0, + 1, 1, 0, 0, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 224, 0, 0, 0, 3, 0, 1, 0, - 236, 0, 0, 0, 2, 0, 1, 0, - 3, 0, 254, 255, 0, 0, 0, 0, + 252, 0, 0, 0, 3, 0, 1, 0, + 8, 1, 0, 0, 2, 0, 1, 0, + 4, 0, 254, 255, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 233, 0, 0, 0, 82, 0, 0, 0, + 5, 1, 0, 0, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 232, 0, 0, 0, 3, 0, 1, 0, - 244, 0, 0, 0, 2, 0, 1, 0, - 4, 0, 253, 255, 0, 0, 0, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 5, 0, 253, 255, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 241, 0, 0, 0, 74, 0, 0, 0, + 13, 1, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 240, 0, 0, 0, 3, 0, 1, 0, - 252, 0, 0, 0, 2, 0, 1, 0, - 5, 0, 252, 255, 0, 0, 0, 0, + 12, 1, 0, 0, 3, 0, 1, 0, + 24, 1, 0, 0, 2, 0, 1, 0, + 6, 0, 252, 255, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 249, 0, 0, 0, 170, 0, 0, 0, + 21, 1, 0, 0, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 252, 0, 0, 0, 3, 0, 1, 0, - 8, 1, 0, 0, 2, 0, 1, 0, - 6, 0, 251, 255, 2, 0, 0, 0, + 24, 1, 0, 0, 3, 0, 1, 0, + 36, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 251, 255, 2, 0, 0, 0, 0, 0, 1, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 5, 1, 0, 0, 178, 0, 0, 0, + 33, 1, 0, 0, 178, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 8, 1, 0, 0, 3, 0, 1, 0, - 20, 1, 0, 0, 2, 0, 1, 0, - 7, 0, 250, 255, 0, 0, 0, 0, + 36, 1, 0, 0, 3, 0, 1, 0, + 48, 1, 0, 0, 2, 0, 1, 0, + 8, 0, 250, 255, 0, 0, 0, 0, 0, 0, 1, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 17, 1, 0, 0, 170, 0, 0, 0, + 45, 1, 0, 0, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 20, 1, 0, 0, 3, 0, 1, 0, - 32, 1, 0, 0, 2, 0, 1, 0, + 48, 1, 0, 0, 3, 0, 1, 0, + 60, 1, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 33, 0, 0, 0, + 0, 0, 1, 0, 8, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 57, 1, 0, 0, 122, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 56, 1, 0, 0, 3, 0, 1, 0, + 68, 1, 0, 0, 2, 0, 1, 0, 97, 110, 115, 119, 101, 114, 73, 100, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -687,6 +728,15 @@ static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 110, 111, 70, 105, 110, 105, 115, 104, + 78, 101, 101, 100, 101, 100, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -696,14 +746,14 @@ static const ::capnp::_::RawSchema* const d_9e19b28d3db3573a[] = { &s_9a0e61223d96743b, &s_d625b7063acf691a, }; -static const uint16_t m_9e19b28d3db3573a[] = {7, 0, 4, 3, 1, 2, 5, 6}; -static const uint16_t i_9e19b28d3db3573a[] = {2, 3, 4, 5, 6, 7, 0, 1}; +static const uint16_t m_9e19b28d3db3573a[] = {7, 0, 4, 3, 8, 1, 2, 5, 6}; +static const uint16_t i_9e19b28d3db3573a[] = {2, 3, 4, 5, 6, 7, 0, 1, 8}; const ::capnp::_::RawSchema s_9e19b28d3db3573a = { - 0x9e19b28d3db3573a, b_9e19b28d3db3573a.words, 148, d_9e19b28d3db3573a, m_9e19b28d3db3573a, - 2, 8, i_9e19b28d3db3573a, nullptr, nullptr, { &s_9e19b28d3db3573a, nullptr, nullptr, 0, 0, nullptr } + 0x9e19b28d3db3573a, b_9e19b28d3db3573a.words, 164, d_9e19b28d3db3573a, m_9e19b28d3db3573a, + 2, 9, i_9e19b28d3db3573a, nullptr, nullptr, { &s_9e19b28d3db3573a, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { +static const ::capnp::_::AlignedData<69> b_d37d2eb2c2f80e63 = { { 0, 0, 0, 0, 5, 0, 6, 0, 99, 14, 248, 194, 178, 46, 125, 211, 16, 0, 0, 0, 1, 0, 1, 0, @@ -713,28 +763,35 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { 21, 0, 0, 0, 186, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 119, 0, 0, 0, + 25, 0, 0, 0, 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 70, 105, 110, 105, 115, 104, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 8, 0, 0, 0, 3, 0, 4, 0, + 12, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 90, 0, 0, 0, + 69, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 40, 0, 0, 0, 3, 0, 1, 0, - 52, 0, 0, 0, 2, 0, 1, 0, + 68, 0, 0, 0, 3, 0, 1, 0, + 80, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 32, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 49, 0, 0, 0, 146, 0, 0, 0, + 77, 0, 0, 0, 146, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 52, 0, 0, 0, 3, 0, 1, 0, - 64, 0, 0, 0, 2, 0, 1, 0, + 80, 0, 0, 0, 3, 0, 1, 0, + 92, 0, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 33, 0, 0, 0, + 0, 0, 1, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 89, 0, 0, 0, 26, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 100, 0, 0, 0, 3, 0, 1, 0, + 112, 0, 0, 0, 2, 0, 1, 0, 113, 117, 101, 115, 116, 105, 111, 110, 73, 100, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -747,6 +804,18 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { 114, 101, 108, 101, 97, 115, 101, 82, 101, 115, 117, 108, 116, 67, 97, 112, 115, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 114, 101, 113, 117, 105, 114, 101, 69, + 97, 114, 108, 121, 67, 97, 110, 99, + 101, 108, 108, 97, 116, 105, 111, 110, + 87, 111, 114, 107, 97, 114, 111, 117, + 110, 100, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -757,11 +826,11 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { }; ::capnp::word const* const bp_d37d2eb2c2f80e63 = b_d37d2eb2c2f80e63.words; #if !CAPNP_LITE -static const uint16_t m_d37d2eb2c2f80e63[] = {0, 1}; -static const uint16_t i_d37d2eb2c2f80e63[] = {0, 1}; +static const uint16_t m_d37d2eb2c2f80e63[] = {0, 1, 2}; +static const uint16_t i_d37d2eb2c2f80e63[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d37d2eb2c2f80e63 = { - 0xd37d2eb2c2f80e63, b_d37d2eb2c2f80e63.words, 50, nullptr, m_d37d2eb2c2f80e63, - 0, 2, i_d37d2eb2c2f80e63, nullptr, nullptr, { &s_d37d2eb2c2f80e63, nullptr, nullptr, 0, 0, nullptr } + 0xd37d2eb2c2f80e63, b_d37d2eb2c2f80e63.words, 69, nullptr, m_d37d2eb2c2f80e63, + 0, 3, i_d37d2eb2c2f80e63, nullptr, nullptr, { &s_d37d2eb2c2f80e63, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_bbc29655fa89086e = { @@ -840,7 +909,7 @@ static const uint16_t m_bbc29655fa89086e[] = {1, 2, 0}; static const uint16_t i_bbc29655fa89086e[] = {1, 2, 0}; const ::capnp::_::RawSchema s_bbc29655fa89086e = { 0xbbc29655fa89086e, b_bbc29655fa89086e.words, 64, d_bbc29655fa89086e, m_bbc29655fa89086e, - 2, 3, i_bbc29655fa89086e, nullptr, nullptr, { &s_bbc29655fa89086e, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_bbc29655fa89086e, nullptr, nullptr, { &s_bbc29655fa89086e, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_ad1a6c0d7dd07497 = { @@ -899,7 +968,7 @@ static const uint16_t m_ad1a6c0d7dd07497[] = {0, 1}; static const uint16_t i_ad1a6c0d7dd07497[] = {0, 1}; const ::capnp::_::RawSchema s_ad1a6c0d7dd07497 = { 0xad1a6c0d7dd07497, b_ad1a6c0d7dd07497.words, 48, nullptr, m_ad1a6c0d7dd07497, - 0, 2, i_ad1a6c0d7dd07497, nullptr, nullptr, { &s_ad1a6c0d7dd07497, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_ad1a6c0d7dd07497, nullptr, nullptr, { &s_ad1a6c0d7dd07497, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<41> b_f964368b0fbd3711 = { @@ -955,7 +1024,7 @@ static const uint16_t m_f964368b0fbd3711[] = {1, 0}; static const uint16_t i_f964368b0fbd3711[] = {0, 1}; const ::capnp::_::RawSchema s_f964368b0fbd3711 = { 0xf964368b0fbd3711, b_f964368b0fbd3711.words, 41, d_f964368b0fbd3711, m_f964368b0fbd3711, - 2, 2, i_f964368b0fbd3711, nullptr, nullptr, { &s_f964368b0fbd3711, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_f964368b0fbd3711, nullptr, nullptr, { &s_f964368b0fbd3711, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<81> b_d562b4df655bdd4d = { @@ -1050,7 +1119,7 @@ static const uint16_t m_d562b4df655bdd4d[] = {2, 3, 1, 0}; static const uint16_t i_d562b4df655bdd4d[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_d562b4df655bdd4d = { 0xd562b4df655bdd4d, b_d562b4df655bdd4d.words, 81, d_d562b4df655bdd4d, m_d562b4df655bdd4d, - 1, 4, i_d562b4df655bdd4d, nullptr, nullptr, { &s_d562b4df655bdd4d, nullptr, nullptr, 0, 0, nullptr } + 1, 4, i_d562b4df655bdd4d, nullptr, nullptr, { &s_d562b4df655bdd4d, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_9c6a046bfbc1ac5a = { @@ -1128,7 +1197,7 @@ static const uint16_t m_9c6a046bfbc1ac5a[] = {0, 2, 1}; static const uint16_t i_9c6a046bfbc1ac5a[] = {0, 1, 2}; const ::capnp::_::RawSchema s_9c6a046bfbc1ac5a = { 0x9c6a046bfbc1ac5a, b_9c6a046bfbc1ac5a.words, 64, d_9c6a046bfbc1ac5a, m_9c6a046bfbc1ac5a, - 1, 3, i_9c6a046bfbc1ac5a, nullptr, nullptr, { &s_9c6a046bfbc1ac5a, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_9c6a046bfbc1ac5a, nullptr, nullptr, { &s_9c6a046bfbc1ac5a, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_d4c9b56290554016 = { @@ -1203,7 +1272,7 @@ static const uint16_t m_d4c9b56290554016[] = {2, 1, 0}; static const uint16_t i_d4c9b56290554016[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d4c9b56290554016 = { 0xd4c9b56290554016, b_d4c9b56290554016.words, 64, nullptr, m_d4c9b56290554016, - 0, 3, i_d4c9b56290554016, nullptr, nullptr, { &s_d4c9b56290554016, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d4c9b56290554016, nullptr, nullptr, { &s_d4c9b56290554016, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<63> b_fbe1980490e001af = { @@ -1280,7 +1349,7 @@ static const uint16_t m_fbe1980490e001af[] = {2, 0, 1}; static const uint16_t i_fbe1980490e001af[] = {0, 1, 2}; const ::capnp::_::RawSchema s_fbe1980490e001af = { 0xfbe1980490e001af, b_fbe1980490e001af.words, 63, d_fbe1980490e001af, m_fbe1980490e001af, - 1, 3, i_fbe1980490e001af, nullptr, nullptr, { &s_fbe1980490e001af, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_fbe1980490e001af, nullptr, nullptr, { &s_fbe1980490e001af, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_95bc14545813fbc1 = { @@ -1344,7 +1413,7 @@ static const uint16_t m_95bc14545813fbc1[] = {0, 1}; static const uint16_t i_95bc14545813fbc1[] = {0, 1}; const ::capnp::_::RawSchema s_95bc14545813fbc1 = { 0x95bc14545813fbc1, b_95bc14545813fbc1.words, 50, d_95bc14545813fbc1, m_95bc14545813fbc1, - 1, 2, i_95bc14545813fbc1, nullptr, nullptr, { &s_95bc14545813fbc1, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_95bc14545813fbc1, nullptr, nullptr, { &s_95bc14545813fbc1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<52> b_9a0e61223d96743b = { @@ -1410,7 +1479,7 @@ static const uint16_t m_9a0e61223d96743b[] = {1, 0}; static const uint16_t i_9a0e61223d96743b[] = {0, 1}; const ::capnp::_::RawSchema s_9a0e61223d96743b = { 0x9a0e61223d96743b, b_9a0e61223d96743b.words, 52, d_9a0e61223d96743b, m_9a0e61223d96743b, - 1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<130> b_8523ddc40b86b8b0 = { @@ -1555,7 +1624,7 @@ static const uint16_t m_8523ddc40b86b8b0[] = {6, 0, 4, 3, 1, 2, 5}; static const uint16_t i_8523ddc40b86b8b0[] = {0, 1, 2, 3, 4, 5, 6}; const ::capnp::_::RawSchema s_8523ddc40b86b8b0 = { 0x8523ddc40b86b8b0, b_8523ddc40b86b8b0.words, 130, d_8523ddc40b86b8b0, m_8523ddc40b86b8b0, - 2, 7, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr } + 2, 7, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<57> b_d800b1d6cd6f1ca0 = { @@ -1626,7 +1695,7 @@ static const uint16_t m_d800b1d6cd6f1ca0[] = {0, 1}; static const uint16_t i_d800b1d6cd6f1ca0[] = {0, 1}; const ::capnp::_::RawSchema s_d800b1d6cd6f1ca0 = { 0xd800b1d6cd6f1ca0, b_d800b1d6cd6f1ca0.words, 57, d_d800b1d6cd6f1ca0, m_d800b1d6cd6f1ca0, - 1, 2, i_d800b1d6cd6f1ca0, nullptr, nullptr, { &s_d800b1d6cd6f1ca0, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_d800b1d6cd6f1ca0, nullptr, nullptr, { &s_d800b1d6cd6f1ca0, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_f316944415569081 = { @@ -1687,7 +1756,7 @@ static const uint16_t m_f316944415569081[] = {1, 0}; static const uint16_t i_f316944415569081[] = {0, 1}; const ::capnp::_::RawSchema s_f316944415569081 = { 0xf316944415569081, b_f316944415569081.words, 50, nullptr, m_f316944415569081, - 0, 2, i_f316944415569081, nullptr, nullptr, { &s_f316944415569081, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_f316944415569081, nullptr, nullptr, { &s_f316944415569081, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_d37007fde1f0027d = { @@ -1747,7 +1816,7 @@ static const uint16_t m_d37007fde1f0027d[] = {0, 1}; static const uint16_t i_d37007fde1f0027d[] = {0, 1}; const ::capnp::_::RawSchema s_d37007fde1f0027d = { 0xd37007fde1f0027d, b_d37007fde1f0027d.words, 49, nullptr, m_d37007fde1f0027d, - 0, 2, i_d37007fde1f0027d, nullptr, nullptr, { &s_d37007fde1f0027d, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_d37007fde1f0027d, nullptr, nullptr, { &s_d37007fde1f0027d, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<100> b_d625b7063acf691a = { @@ -1861,7 +1930,7 @@ static const uint16_t m_d625b7063acf691a[] = {2, 1, 0, 4, 3}; static const uint16_t i_d625b7063acf691a[] = {0, 1, 2, 3, 4}; const ::capnp::_::RawSchema s_d625b7063acf691a = { 0xd625b7063acf691a, b_d625b7063acf691a.words, 100, d_d625b7063acf691a, m_d625b7063acf691a, - 1, 5, i_d625b7063acf691a, nullptr, nullptr, { &s_d625b7063acf691a, nullptr, nullptr, 0, 0, nullptr } + 1, 5, i_d625b7063acf691a, nullptr, nullptr, { &s_d625b7063acf691a, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_b28c96e23f4cbd58 = { @@ -1908,7 +1977,7 @@ static const ::capnp::_::AlignedData<37> b_b28c96e23f4cbd58 = { static const uint16_t m_b28c96e23f4cbd58[] = {2, 0, 1, 3}; const ::capnp::_::RawSchema s_b28c96e23f4cbd58 = { 0xb28c96e23f4cbd58, b_b28c96e23f4cbd58.words, 37, nullptr, m_b28c96e23f4cbd58, - 0, 4, nullptr, nullptr, nullptr, { &s_b28c96e23f4cbd58, nullptr, nullptr, 0, 0, nullptr } + 0, 4, nullptr, nullptr, nullptr, { &s_b28c96e23f4cbd58, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(Type_b28c96e23f4cbd58, b28c96e23f4cbd58); @@ -1921,163 +1990,243 @@ namespace capnp { namespace rpc { // Message +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Message::_capnpPrivate::dataWordSize; constexpr uint16_t Message::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Message::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Message::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Bootstrap +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Bootstrap::_capnpPrivate::dataWordSize; constexpr uint16_t Bootstrap::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Bootstrap::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Bootstrap::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Call +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Call::_capnpPrivate::dataWordSize; constexpr uint16_t Call::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Call::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Call::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Call::SendResultsTo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Call::SendResultsTo::_capnpPrivate::dataWordSize; constexpr uint16_t Call::SendResultsTo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Call::SendResultsTo::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Call::SendResultsTo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Return +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Return::_capnpPrivate::dataWordSize; constexpr uint16_t Return::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Return::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Return::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Finish +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Finish::_capnpPrivate::dataWordSize; constexpr uint16_t Finish::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Finish::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Finish::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Resolve +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Resolve::_capnpPrivate::dataWordSize; constexpr uint16_t Resolve::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Resolve::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Resolve::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Release +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Release::_capnpPrivate::dataWordSize; constexpr uint16_t Release::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Release::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Release::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Disembargo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Disembargo::_capnpPrivate::dataWordSize; constexpr uint16_t Disembargo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Disembargo::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Disembargo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Disembargo::Context +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Disembargo::Context::_capnpPrivate::dataWordSize; constexpr uint16_t Disembargo::Context::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Disembargo::Context::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Disembargo::Context::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Provide +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Provide::_capnpPrivate::dataWordSize; constexpr uint16_t Provide::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Provide::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Provide::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Accept +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Accept::_capnpPrivate::dataWordSize; constexpr uint16_t Accept::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Accept::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Accept::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Join +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Join::_capnpPrivate::dataWordSize; constexpr uint16_t Join::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Join::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Join::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // MessageTarget +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t MessageTarget::_capnpPrivate::dataWordSize; constexpr uint16_t MessageTarget::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind MessageTarget::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* MessageTarget::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Payload +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Payload::_capnpPrivate::dataWordSize; constexpr uint16_t Payload::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Payload::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Payload::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CapDescriptor +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CapDescriptor::_capnpPrivate::dataWordSize; constexpr uint16_t CapDescriptor::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CapDescriptor::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CapDescriptor::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // PromisedAnswer +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t PromisedAnswer::_capnpPrivate::dataWordSize; constexpr uint16_t PromisedAnswer::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind PromisedAnswer::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* PromisedAnswer::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // PromisedAnswer::Op +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t PromisedAnswer::Op::_capnpPrivate::dataWordSize; constexpr uint16_t PromisedAnswer::Op::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind PromisedAnswer::Op::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* PromisedAnswer::Op::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ThirdPartyCapDescriptor +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ThirdPartyCapDescriptor::_capnpPrivate::dataWordSize; constexpr uint16_t ThirdPartyCapDescriptor::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ThirdPartyCapDescriptor::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ThirdPartyCapDescriptor::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Exception +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Exception::_capnpPrivate::dataWordSize; constexpr uint16_t Exception::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Exception::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Exception::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.h index 58eb6a2dd38..ead290923f8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif @@ -764,6 +766,10 @@ class Call::Reader { inline bool getAllowThirdPartyTailCall() const; + inline bool getNoPromisePipelining() const; + + inline bool getOnlyPromisePipeline() const; + private: ::capnp::_::StructReader _reader; template @@ -821,6 +827,12 @@ class Call::Builder { inline bool getAllowThirdPartyTailCall(); inline void setAllowThirdPartyTailCall(bool value); + inline bool getNoPromisePipelining(); + inline void setNoPromisePipelining(bool value); + + inline bool getOnlyPromisePipeline(); + inline void setOnlyPromisePipeline(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -989,6 +1001,8 @@ class Return::Reader { inline bool hasAcceptFromThirdParty() const; inline ::capnp::AnyPointer::Reader getAcceptFromThirdParty() const; + inline bool getNoFinishNeeded() const; + private: ::capnp::_::StructReader _reader; template @@ -1057,6 +1071,9 @@ class Return::Builder { inline ::capnp::AnyPointer::Builder getAcceptFromThirdParty(); inline ::capnp::AnyPointer::Builder initAcceptFromThirdParty(); + inline bool getNoFinishNeeded(); + inline void setNoFinishNeeded(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -1104,6 +1121,8 @@ class Finish::Reader { inline bool getReleaseResultCaps() const; + inline bool getRequireEarlyCancellationWorkaround() const; + private: ::capnp::_::StructReader _reader; template @@ -1138,6 +1157,9 @@ class Finish::Builder { inline bool getReleaseResultCaps(); inline void setReleaseResultCaps(bool value); + inline bool getRequireEarlyCancellationWorkaround(); + inline void setRequireEarlyCancellationWorkaround(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -3404,6 +3426,34 @@ inline void Call::Builder::setAllowThirdPartyTailCall(bool value) { ::capnp::bounded<128>() * ::capnp::ELEMENTS, value); } +inline bool Call::Reader::getNoPromisePipelining() const { + return _reader.getDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS); +} + +inline bool Call::Builder::getNoPromisePipelining() { + return _builder.getDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS); +} +inline void Call::Builder::setNoPromisePipelining(bool value) { + _builder.setDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS, value); +} + +inline bool Call::Reader::getOnlyPromisePipeline() const { + return _reader.getDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS); +} + +inline bool Call::Builder::getOnlyPromisePipeline() { + return _builder.getDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS); +} +inline void Call::Builder::setOnlyPromisePipeline(bool value) { + _builder.setDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS, value); +} + inline ::capnp::rpc::Call::SendResultsTo::Which Call::SendResultsTo::Reader::which() const { return _reader.getDataField( ::capnp::bounded<3>() * ::capnp::ELEMENTS); @@ -3762,6 +3812,20 @@ inline ::capnp::AnyPointer::Builder Return::Builder::initAcceptFromThirdParty() return result; } +inline bool Return::Reader::getNoFinishNeeded() const { + return _reader.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS); +} + +inline bool Return::Builder::getNoFinishNeeded() { + return _builder.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS); +} +inline void Return::Builder::setNoFinishNeeded(bool value) { + _builder.setDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, value); +} + inline ::uint32_t Finish::Reader::getQuestionId() const { return _reader.getDataField< ::uint32_t>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); @@ -3790,6 +3854,20 @@ inline void Finish::Builder::setReleaseResultCaps(bool value) { ::capnp::bounded<32>() * ::capnp::ELEMENTS, value, true); } +inline bool Finish::Reader::getRequireEarlyCancellationWorkaround() const { + return _reader.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, true); +} + +inline bool Finish::Builder::getRequireEarlyCancellationWorkaround() { + return _builder.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, true); +} +inline void Finish::Builder::setRequireEarlyCancellationWorkaround(bool value) { + _builder.setDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, value, true); +} + inline ::capnp::rpc::Resolve::Which Resolve::Reader::which() const { return _reader.getDataField( ::capnp::bounded<2>() * ::capnp::ELEMENTS); diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.h b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.h index 8a0dede4985..c4df2f04e2e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/rpc.h @@ -21,7 +21,7 @@ #pragma once -#include "capability.h" +#include #include "rpc-prelude.h" CAPNP_BEGIN_HEADER @@ -36,6 +36,8 @@ class VatNetwork; template class SturdyRefRestorer; +class MessageReader; + template class BootstrapFactory: public _::BootstrapFactoryBase { // Interface that constructs per-client bootstrap interfaces. Use this if you want each client @@ -293,6 +295,16 @@ class IncomingRpcMessage { // Get the total size of the message, for flow control purposes. Although the caller could // also call getBody().targetSize(), doing that would walk the message tree, whereas typical // implementations can compute the size more cheaply by summing segment sizes. + + static bool isShortLivedRpcMessage(AnyPointer::Reader body); + // Helper function which computes whether the standard RpcSystem implementation would consider + // the given message body to be short-lived, meaning it will be dropped before the next message + // is read. This is useful to implement BufferedMessageStream::IsShortLivedCallback. + + static kj::Function getShortLivedCallback(); + // Returns a function that wraps isShortLivedRpcMessage(). The returned function type matches + // `BufferedMessageStream::IsShortLivedCallback` (defined in serialize-async.h), but we don't + // include that header here. }; class RpcFlowController { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader-test.c++ index 8e7c2d749c9..c2b2651bf96 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader-test.c++ @@ -400,6 +400,36 @@ TEST(SchemaLoader, LoadStreaming) { KJ_EXPECT(results.getShortDisplayName() == "StreamResult", results.getShortDisplayName()); } +KJ_TEST("SchemaLoader placeholders are assumed to have caps") { + // Load TestCycle*NoCaps, but don't load its dependency TestAllTypes, so the loader has to assume + // there may be caps. + { + SchemaLoader loader; + Schema schemaA = loader.load(Schema::from().getProto()); + Schema schemaB = loader.load(Schema::from().getProto()); + loader.computeOptimizationHints(); + + KJ_EXPECT(schemaA.asStruct().mayContainCapabilities()); + KJ_EXPECT(schemaB.asStruct().mayContainCapabilities()); + } + + // Try again, but actually load TestAllTypes. Now we recognize there's no caps. + { + SchemaLoader loader; + Schema schemaA = loader.load(Schema::from().getProto()); + Schema schemaB = loader.load(Schema::from().getProto()); + loader.load(Schema::from().getProto()); + loader.computeOptimizationHints(); + + KJ_EXPECT(!schemaA.asStruct().mayContainCapabilities()); + KJ_EXPECT(!schemaB.asStruct().mayContainCapabilities()); + } + + // NOTE: computeOptimizationHints() is also tested in `schema-test.c++` where we test that + // various compiled types have the correct hints, which relies on the code generator having + // computed the hints. +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.c++ index 154f37de0ca..7c056c50417 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.c++ @@ -114,6 +114,8 @@ public: kj::Array getAllLoaded() const; + void computeOptimizationHints(); + void requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount); // Require any struct nodes loaded with this ID -- in the past and in the future -- to have at // least the given sizes. Struct nodes that don't comply will simply be rewritten to comply. @@ -1827,6 +1829,175 @@ kj::Array SchemaLoader::Impl::getAllLoaded() const { return result; } +void SchemaLoader::Impl::computeOptimizationHints() { + kj::HashMap<_::RawSchema*, kj::Vector<_::RawSchema*>> undecided; + // This map contains schemas for which we haven't yet decided if they might have capabilities. + // They at least do not directly contain capabilities, but they can't be fully decided until + // the dependents are decided. + // + // Each entry maps to a list of other schemas whose decisions depend on this schema. When a + // schema in the map is discovered to contain capabilities, then all these dependents must also + // be presumed to contain capabilities. + + // First pass: Decide on the easy cases and populate the `undecided` map with hard cases. + for (auto& entry: schemas) { + _::RawSchema* schema = entry.value; + + // Default to assuming everything could contain caps. + schema->mayContainCapabilities = true; + + if (schema->lazyInitializer != nullptr) { + // Not initialized yet, so we have to be conservative and assume there could be capabilities. + continue; + } + + auto node = readMessageUnchecked(schema->encodedNode); + + if (!node.isStruct()) { + // Non-structs are irrelevant. + continue; + } + + auto structSchema = node.getStruct(); + + bool foundAnyCaps = false; + bool foundAnyStructs = false; + for (auto field: structSchema.getFields()) { + switch (field.which()) { + case schema::Field::GROUP: + foundAnyStructs = true; + break; + case schema::Field::SLOT: { + auto type = field.getSlot().getType(); + while (type.isList()) { + type = type.getList().getElementType(); + } + + switch (type.which()) { + case schema::Type::VOID: + case schema::Type::BOOL: + case schema::Type::INT8: + case schema::Type::INT16: + case schema::Type::INT32: + case schema::Type::INT64: + case schema::Type::UINT8: + case schema::Type::UINT16: + case schema::Type::UINT32: + case schema::Type::UINT64: + case schema::Type::FLOAT32: + case schema::Type::FLOAT64: + case schema::Type::TEXT: + case schema::Type::DATA: + case schema::Type::ENUM: + // Not a capability. + break; + + case schema::Type::STRUCT: + foundAnyStructs = true; + break; + + case schema::Type::ANY_POINTER: // could be a capability, or transitively contain one + case schema::Type::INTERFACE: // definitely a capability + foundAnyCaps = true; + break; + + case schema::Type::LIST: + KJ_UNREACHABLE; // handled above + } + break; + } + } + + if (foundAnyCaps) break; // no point continuing + } + + if (foundAnyCaps) { + // Definitely has capabilities, don't add to `undecided`. + } else if (!foundAnyStructs) { + // Definitely does NOT have capabilities. Go ahead and set the hint and don't add to + // `undecided`. + schema->mayContainCapabilities = false; + } else { + // Don't know yet. Mark as no-capabilities for now, but place in `undecided` set to review + // later. + schema->mayContainCapabilities = false; + undecided.insert(schema, {}); + } + } + + // Second pass: For all undecided schemas, check dependencies and register as dependents where + // needed. + kj::Vector<_::RawSchema*> decisions; // Schemas that have become decided. + for (auto& entry: undecided) { + auto schema = entry.key; + + auto node = readMessageUnchecked(schema->encodedNode).getStruct(); + + for (auto field: node.getFields()) { + kj::Maybe depId; + + switch (field.which()) { + case schema::Field::GROUP: + depId = field.getGroup().getTypeId(); + break; + case schema::Field::SLOT: { + auto type = field.getSlot().getType(); + while (type.isList()) { + type = type.getList().getElementType(); + } + if (type.isStruct()) { + depId = type.getStruct().getTypeId(); + } + break; + } + } + + KJ_IF_MAYBE(d, depId) { + _::RawSchema* dep = KJ_ASSERT_NONNULL(schemas.find(*d)); + + if (dep->mayContainCapabilities) { + // Oops, this dependency is already known to have capabilities. So that means the current + // schema also has capabilities, transitively. Mark it as such. + schema->mayContainCapabilities = true; + + // Schedule this schema for removal later. + decisions.add(schema); + + // Might as well end the loop early. + break; + } else KJ_IF_MAYBE(undecidedEntry, undecided.find(dep)) { + // This dependency is in the undecided set. Register interest in it. + undecidedEntry->add(schema); + } else { + // This dependency is decided, and the decision is that it has no capabilities. So it + // has no impact on the dependent. + } + } + } + } + + // Third pass: For each decision we made, remove it and propagate to its dependents. + while (!decisions.empty()) { + _::RawSchema* decision = decisions.back(); + decisions.removeLast(); + + auto& entry = KJ_ASSERT_NONNULL(undecided.findEntry(decision)); + for (auto& dependent: entry.value) { + if (!dependent->mayContainCapabilities) { + // The dependent was not previously decided. But, we now know it has a dependency which has + // capabilities, therefore we can decide the dependent. + dependent->mayContainCapabilities = true; + decisions.add(dependent); + } + } + undecided.erase(entry); + } + + // Everything that is left in `undecided` must only be waiting on other undecided schemas. We + // can therefore decide that none of them have any capabilities. We marked them as such + // earlier so now we're all done. +} + void SchemaLoader::Impl::requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount) { structSizeRequirements.upsert(id, { uint16_t(dataWordCount), uint16_t(pointerCount) }, [&](RequiredSize& existingValue, RequiredSize&& newValue) { @@ -2086,6 +2257,10 @@ kj::Array SchemaLoader::getAllLoaded() const { return impl.lockShared()->get()->getAllLoaded(); } +void SchemaLoader::computeOptimizationHints() { + impl.lockExclusive()->get()->computeOptimizationHints(); +} + void SchemaLoader::loadNative(const _::RawSchema* nativeSchema) { impl.lockExclusive()->get()->loadNative(nativeSchema); } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.h b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.h index 90533158eb9..5db8364c428 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-loader.h @@ -64,7 +64,7 @@ class SchemaLoader { // that isn't already loaded. ~SchemaLoader() noexcept(false); - KJ_DISALLOW_COPY(SchemaLoader); + KJ_DISALLOW_COPY_AND_MOVE(SchemaLoader); Schema get(uint64_t id, schema::Brand::Reader brand = schema::Brand::Reader(), Schema scope = Schema()) const; @@ -149,6 +149,19 @@ class SchemaLoader { // loadCompiledTypeAndDependencies() in order to get a flat list of all of T's transitive // dependencies. + void computeOptimizationHints(); + // Call after all interesting schemas have been loaded to compute optimization hints. In + // particular, this initializes `hasNoCapabilities` for every struct type. Before this is called, + // that value is initialized to false for all types (which ensures correct behavior but does not + // allow the optimization). + // + // If any loaded struct types contain fields of types for which no schema has been loaded, they + // will be presumed to possibly contain capabilities. `LazyLoadCallback` will NOT be invoked to + // load any types that haven't been loaded yet. + // + // TODO(someday): Perhaps we could dynamically initialize the hints on-demand, but it would be + // much more work to implement. + private: class Validator; class CompatibilityChecker; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser-test.c++ index c9435943ffd..5bc618859c7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser-test.c++ @@ -272,5 +272,27 @@ TEST(SchemaParser, SourceInfo) { expectSourceInfo(thud.getSourceInfo(), 0xcca9972702b730b4, "post-comment\n", {}); } +TEST(SchemaParser, SetFileIdsRequired) { + FakeFileReader reader; + reader.add("no-file-id.capnp", + "const foo :Int32 = 123;\n"); + + { + SchemaParser parser; + parser.setDiskFilesystem(reader); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("File does not declare an ID.", + parser.parseDiskFile("no-file-id.capnp", "no-file-id.capnp", nullptr)); + } + { + SchemaParser parser; + parser.setDiskFilesystem(reader); + parser.setFileIdsRequired(false); + + auto fileSchema = parser.parseDiskFile("no-file-id.capnp", "no-file-id.capnp", nullptr); + KJ_EXPECT(fileSchema.getNested("foo").asConst().as() == 123); + } +} + } // namespace } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.c++ index 909be25cc8b..b5aeec12dae 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.c++ @@ -88,7 +88,7 @@ public: compiler::lex(content, statements, *this); auto parsed = orphanage.newOrphan(); - compiler::parseFile(statements.getStatements(), parsed.get(), *this); + compiler::parseFile(statements.getStatements(), parsed.get(), *this, parser.fileIdsRequired); return parsed; } @@ -148,7 +148,7 @@ private: namespace { struct SchemaFileHash { - inline bool operator()(const SchemaFile* f) const { + inline size_t operator()(const SchemaFile* f) const { return f->hashCode(); } }; @@ -313,6 +313,10 @@ SchemaLoader& SchemaParser::getLoader() { return impl->compiler.getLoader(); } +const SchemaLoader& SchemaParser::getLoader() const { + return impl->compiler.getLoader(); +} + kj::Maybe ParsedSchema::findNested(kj::StringPtr name) const { // TODO(someday): lookup() doesn't handle generics correctly. Use the ModuleScope/CompiledType // interface instead. We can also add an applybrand() method to ParsedSchema using those diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.h b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.h index 283b8cea638..6c48763771f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-parser.h @@ -141,14 +141,35 @@ class SchemaParser { getLoader().loadCompiledTypeAndDependencies(); } + kj::Array getAllLoaded() const { + // Gets an array of all schema nodes that have been parsed so far. + return getLoader().getAllLoaded(); + } + + void setFileIdsRequired(bool value) { fileIdsRequired = value; } + // By befault, capnp files must declare a file-level type ID (like `@0xbe702824338d3f7f;`). + // Use `setFileIdsReqired(false)` to lift this requirement. + // + // If no ID is specified, a random one will be assigned. This will cause all types declared in + // the file to have randomized IDs as well (unless they declare an ID explicitly), which means + // that parsing the same file twice will appear to to produce a totally new, incompatible set of + // types. In particular, this means that you will not be able to use any interface types in the + // file for RPC, since the RPC protocol uses type IDs to identify methods. + // + // Setting this false is particularly useful when using Cap'n Proto as a config format. Typically + // type IDs are irrelevant for config files, and the requirement to specify one is cumbersome. + // For this reason, `capnp eval` does not require type ID to be present. + private: struct Impl; struct DiskFileCompat; class ModuleImpl; kj::Own impl; mutable bool hadErrors = false; + bool fileIdsRequired = true; ModuleImpl& getModuleImpl(kj::Own&& file) const; + const SchemaLoader& getLoader() const; SchemaLoader& getLoader(); friend class ParsedSchema; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-test.c++ index f33152e2c52..94b3c9a470e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema-test.c++ @@ -371,6 +371,43 @@ TEST(Schema, Generics) { } } +KJ_TEST("StructSchema::hasNoCapabilites()") { + // At present, TestAllTypes doesn't actually cover interfaces or AnyPointer. + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(Schema::from().mayContainCapabilities()); + KJ_EXPECT(Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + // Generic arguments could be capabilities. + KJ_EXPECT(Schema::from::Inner>().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(Schema::from().mayContainCapabilities()); + KJ_EXPECT(Schema::from().mayContainCapabilities()); +} + +KJ_TEST("list-of-enum as generic type parameter has working schema") { + // Tests for a bug where when a list-of-enum type was used as a type parameter to a generic, + // the schema would be constructed wrong. + auto field = Schema::from() + .getFieldByName("bindEnumList").getType().asStruct() + .getFieldByName("foo"); + auto type = field.getType(); + KJ_ASSERT(type.isList()); + auto elementType = type.asList().getElementType(); + KJ_ASSERT(elementType.isEnum()); + KJ_ASSERT(elementType.asEnum() == Schema::from()); +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.c++ index 8df65313a62..9ec5df62eae 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.c++ @@ -265,6 +265,19 @@ Schema::BrandArgumentList Schema::getBrandArgumentsAtScope(uint64_t scopeId) con return BrandArgumentList(scopeId, raw->isUnbound()); } +kj::Array Schema::getGenericScopeIds() const { + if (!getProto().getIsGeneric()) + return nullptr; + + auto result = kj::heapArray(raw->scopeCount); + for (auto iScope: kj::indices(result)) { + result[iScope] = raw->scopes[iScope].typeId; + } + + return result; +} + + StructSchema Schema::asStruct() const { KJ_REQUIRE(getProto().isStruct(), "Tried to use non-struct schema as a struct.", getProto().getDisplayName()) { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.c++ index efe699fd067..1b7c7c2ef87 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.c++ @@ -248,7 +248,7 @@ static const uint16_t m_e682ab4cf923a417[] = {11, 5, 10, 1, 2, 8, 6, 0, 9, 13, 4 static const uint16_t i_e682ab4cf923a417[] = {6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13}; const ::capnp::_::RawSchema s_e682ab4cf923a417 = { 0xe682ab4cf923a417, b_e682ab4cf923a417.words, 225, d_e682ab4cf923a417, m_e682ab4cf923a417, - 8, 14, i_e682ab4cf923a417, nullptr, nullptr, { &s_e682ab4cf923a417, nullptr, nullptr, 0, 0, nullptr } + 8, 14, i_e682ab4cf923a417, nullptr, nullptr, { &s_e682ab4cf923a417, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_b9521bccf10fa3b1 = { @@ -293,7 +293,7 @@ static const uint16_t m_b9521bccf10fa3b1[] = {0}; static const uint16_t i_b9521bccf10fa3b1[] = {0}; const ::capnp::_::RawSchema s_b9521bccf10fa3b1 = { 0xb9521bccf10fa3b1, b_b9521bccf10fa3b1.words, 34, nullptr, m_b9521bccf10fa3b1, - 0, 1, i_b9521bccf10fa3b1, nullptr, nullptr, { &s_b9521bccf10fa3b1, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_b9521bccf10fa3b1, nullptr, nullptr, { &s_b9521bccf10fa3b1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_debf55bbfa0fc242 = { @@ -353,7 +353,7 @@ static const uint16_t m_debf55bbfa0fc242[] = {1, 0}; static const uint16_t i_debf55bbfa0fc242[] = {0, 1}; const ::capnp::_::RawSchema s_debf55bbfa0fc242 = { 0xdebf55bbfa0fc242, b_debf55bbfa0fc242.words, 49, nullptr, m_debf55bbfa0fc242, - 0, 2, i_debf55bbfa0fc242, nullptr, nullptr, { &s_debf55bbfa0fc242, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_debf55bbfa0fc242, nullptr, nullptr, { &s_debf55bbfa0fc242, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<72> b_f38e1de3041357ae = { @@ -439,7 +439,7 @@ static const uint16_t m_f38e1de3041357ae[] = {1, 0, 2}; static const uint16_t i_f38e1de3041357ae[] = {0, 1, 2}; const ::capnp::_::RawSchema s_f38e1de3041357ae = { 0xf38e1de3041357ae, b_f38e1de3041357ae.words, 72, d_f38e1de3041357ae, m_f38e1de3041357ae, - 1, 3, i_f38e1de3041357ae, nullptr, nullptr, { &s_f38e1de3041357ae, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_f38e1de3041357ae, nullptr, nullptr, { &s_f38e1de3041357ae, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<36> b_c2ba9038898e1fa2 = { @@ -486,7 +486,7 @@ static const uint16_t m_c2ba9038898e1fa2[] = {0}; static const uint16_t i_c2ba9038898e1fa2[] = {0}; const ::capnp::_::RawSchema s_c2ba9038898e1fa2 = { 0xc2ba9038898e1fa2, b_c2ba9038898e1fa2.words, 36, nullptr, m_c2ba9038898e1fa2, - 0, 1, i_c2ba9038898e1fa2, nullptr, nullptr, { &s_c2ba9038898e1fa2, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_c2ba9038898e1fa2, nullptr, nullptr, { &s_c2ba9038898e1fa2, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<134> b_9ea0b19b37fb4435 = { @@ -636,7 +636,7 @@ static const uint16_t m_9ea0b19b37fb4435[] = {0, 4, 5, 6, 3, 1, 2}; static const uint16_t i_9ea0b19b37fb4435[] = {0, 1, 2, 3, 4, 5, 6}; const ::capnp::_::RawSchema s_9ea0b19b37fb4435 = { 0x9ea0b19b37fb4435, b_9ea0b19b37fb4435.words, 134, d_9ea0b19b37fb4435, m_9ea0b19b37fb4435, - 3, 7, i_9ea0b19b37fb4435, nullptr, nullptr, { &s_9ea0b19b37fb4435, nullptr, nullptr, 0, 0, nullptr } + 3, 7, i_9ea0b19b37fb4435, nullptr, nullptr, { &s_9ea0b19b37fb4435, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_b54ab3364333f598 = { @@ -688,7 +688,7 @@ static const uint16_t m_b54ab3364333f598[] = {0}; static const uint16_t i_b54ab3364333f598[] = {0}; const ::capnp::_::RawSchema s_b54ab3364333f598 = { 0xb54ab3364333f598, b_b54ab3364333f598.words, 37, d_b54ab3364333f598, m_b54ab3364333f598, - 2, 1, i_b54ab3364333f598, nullptr, nullptr, { &s_b54ab3364333f598, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_b54ab3364333f598, nullptr, nullptr, { &s_b54ab3364333f598, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<57> b_e82753cff0c2218f = { @@ -761,7 +761,7 @@ static const uint16_t m_e82753cff0c2218f[] = {0, 1}; static const uint16_t i_e82753cff0c2218f[] = {0, 1}; const ::capnp::_::RawSchema s_e82753cff0c2218f = { 0xe82753cff0c2218f, b_e82753cff0c2218f.words, 57, d_e82753cff0c2218f, m_e82753cff0c2218f, - 3, 2, i_e82753cff0c2218f, nullptr, nullptr, { &s_e82753cff0c2218f, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_e82753cff0c2218f, nullptr, nullptr, { &s_e82753cff0c2218f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_b18aa5ac7a0d9420 = { @@ -824,7 +824,7 @@ static const uint16_t m_b18aa5ac7a0d9420[] = {0, 1}; static const uint16_t i_b18aa5ac7a0d9420[] = {0, 1}; const ::capnp::_::RawSchema s_b18aa5ac7a0d9420 = { 0xb18aa5ac7a0d9420, b_b18aa5ac7a0d9420.words, 47, d_b18aa5ac7a0d9420, m_b18aa5ac7a0d9420, - 3, 2, i_b18aa5ac7a0d9420, nullptr, nullptr, { &s_b18aa5ac7a0d9420, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_b18aa5ac7a0d9420, nullptr, nullptr, { &s_b18aa5ac7a0d9420, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<228> b_ec1619d4400a0290 = { @@ -1067,7 +1067,7 @@ static const uint16_t m_ec1619d4400a0290[] = {12, 2, 3, 4, 6, 1, 8, 9, 10, 11, 5 static const uint16_t i_ec1619d4400a0290[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; const ::capnp::_::RawSchema s_ec1619d4400a0290 = { 0xec1619d4400a0290, b_ec1619d4400a0290.words, 228, d_ec1619d4400a0290, m_ec1619d4400a0290, - 2, 13, i_ec1619d4400a0290, nullptr, nullptr, { &s_ec1619d4400a0290, nullptr, nullptr, 0, 0, nullptr } + 2, 13, i_ec1619d4400a0290, nullptr, nullptr, { &s_ec1619d4400a0290, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<114> b_9aad50a41f4af45f = { @@ -1198,7 +1198,7 @@ static const uint16_t m_9aad50a41f4af45f[] = {2, 1, 3, 5, 0, 6, 4}; static const uint16_t i_9aad50a41f4af45f[] = {4, 5, 0, 1, 2, 3, 6}; const ::capnp::_::RawSchema s_9aad50a41f4af45f = { 0x9aad50a41f4af45f, b_9aad50a41f4af45f.words, 114, d_9aad50a41f4af45f, m_9aad50a41f4af45f, - 4, 7, i_9aad50a41f4af45f, nullptr, nullptr, { &s_9aad50a41f4af45f, nullptr, nullptr, 0, 0, nullptr } + 4, 7, i_9aad50a41f4af45f, nullptr, nullptr, { &s_9aad50a41f4af45f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<25> b_97b14cbe7cfec712 = { @@ -1232,7 +1232,7 @@ static const ::capnp::_::AlignedData<25> b_97b14cbe7cfec712 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_97b14cbe7cfec712 = { 0x97b14cbe7cfec712, b_97b14cbe7cfec712.words, 25, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_97b14cbe7cfec712, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_97b14cbe7cfec712, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<80> b_c42305476bb4746f = { @@ -1328,7 +1328,7 @@ static const uint16_t m_c42305476bb4746f[] = {2, 3, 0, 1}; static const uint16_t i_c42305476bb4746f[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_c42305476bb4746f = { 0xc42305476bb4746f, b_c42305476bb4746f.words, 80, d_c42305476bb4746f, m_c42305476bb4746f, - 3, 4, i_c42305476bb4746f, nullptr, nullptr, { &s_c42305476bb4746f, nullptr, nullptr, 0, 0, nullptr } + 3, 4, i_c42305476bb4746f, nullptr, nullptr, { &s_c42305476bb4746f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<32> b_cafccddb68db1d11 = { @@ -1374,7 +1374,7 @@ static const uint16_t m_cafccddb68db1d11[] = {0}; static const uint16_t i_cafccddb68db1d11[] = {0}; const ::capnp::_::RawSchema s_cafccddb68db1d11 = { 0xcafccddb68db1d11, b_cafccddb68db1d11.words, 32, d_cafccddb68db1d11, m_cafccddb68db1d11, - 1, 1, i_cafccddb68db1d11, nullptr, nullptr, { &s_cafccddb68db1d11, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_cafccddb68db1d11, nullptr, nullptr, { &s_cafccddb68db1d11, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_bb90d5c287870be6 = { @@ -1438,7 +1438,7 @@ static const uint16_t m_bb90d5c287870be6[] = {1, 0}; static const uint16_t i_bb90d5c287870be6[] = {0, 1}; const ::capnp::_::RawSchema s_bb90d5c287870be6 = { 0xbb90d5c287870be6, b_bb90d5c287870be6.words, 50, d_bb90d5c287870be6, m_bb90d5c287870be6, - 1, 2, i_bb90d5c287870be6, nullptr, nullptr, { &s_bb90d5c287870be6, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_bb90d5c287870be6, nullptr, nullptr, { &s_bb90d5c287870be6, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<69> b_978a7cebdc549a4d = { @@ -1521,7 +1521,7 @@ static const uint16_t m_978a7cebdc549a4d[] = {2, 1, 0}; static const uint16_t i_978a7cebdc549a4d[] = {0, 1, 2}; const ::capnp::_::RawSchema s_978a7cebdc549a4d = { 0x978a7cebdc549a4d, b_978a7cebdc549a4d.words, 69, d_978a7cebdc549a4d, m_978a7cebdc549a4d, - 1, 3, i_978a7cebdc549a4d, nullptr, nullptr, { &s_978a7cebdc549a4d, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_978a7cebdc549a4d, nullptr, nullptr, { &s_978a7cebdc549a4d, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_a9962a9ed0a4d7f8 = { @@ -1583,7 +1583,7 @@ static const uint16_t m_a9962a9ed0a4d7f8[] = {1, 0}; static const uint16_t i_a9962a9ed0a4d7f8[] = {0, 1}; const ::capnp::_::RawSchema s_a9962a9ed0a4d7f8 = { 0xa9962a9ed0a4d7f8, b_a9962a9ed0a4d7f8.words, 48, d_a9962a9ed0a4d7f8, m_a9962a9ed0a4d7f8, - 1, 2, i_a9962a9ed0a4d7f8, nullptr, nullptr, { &s_a9962a9ed0a4d7f8, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_a9962a9ed0a4d7f8, nullptr, nullptr, { &s_a9962a9ed0a4d7f8, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<155> b_9500cce23b334d80 = { @@ -1754,7 +1754,7 @@ static const uint16_t m_9500cce23b334d80[] = {4, 1, 7, 0, 5, 2, 6, 3}; static const uint16_t i_9500cce23b334d80[] = {0, 1, 2, 3, 4, 5, 6, 7}; const ::capnp::_::RawSchema s_9500cce23b334d80 = { 0x9500cce23b334d80, b_9500cce23b334d80.words, 155, d_9500cce23b334d80, m_9500cce23b334d80, - 3, 8, i_9500cce23b334d80, nullptr, nullptr, { &s_9500cce23b334d80, nullptr, nullptr, 0, 0, nullptr } + 3, 8, i_9500cce23b334d80, nullptr, nullptr, { &s_9500cce23b334d80, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<269> b_d07378ede1f9cc60 = { @@ -2041,7 +2041,7 @@ static const uint16_t m_d07378ede1f9cc60[] = {18, 1, 13, 15, 10, 11, 3, 4, 5, 2, static const uint16_t i_d07378ede1f9cc60[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; const ::capnp::_::RawSchema s_d07378ede1f9cc60 = { 0xd07378ede1f9cc60, b_d07378ede1f9cc60.words, 269, d_d07378ede1f9cc60, m_d07378ede1f9cc60, - 5, 19, i_d07378ede1f9cc60, nullptr, nullptr, { &s_d07378ede1f9cc60, nullptr, nullptr, 0, 0, nullptr } + 5, 19, i_d07378ede1f9cc60, nullptr, nullptr, { &s_d07378ede1f9cc60, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<33> b_87e739250a60ea97 = { @@ -2088,7 +2088,7 @@ static const uint16_t m_87e739250a60ea97[] = {0}; static const uint16_t i_87e739250a60ea97[] = {0}; const ::capnp::_::RawSchema s_87e739250a60ea97 = { 0x87e739250a60ea97, b_87e739250a60ea97.words, 33, d_87e739250a60ea97, m_87e739250a60ea97, - 1, 1, i_87e739250a60ea97, nullptr, nullptr, { &s_87e739250a60ea97, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_87e739250a60ea97, nullptr, nullptr, { &s_87e739250a60ea97, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_9e0e78711a7f87a9 = { @@ -2150,7 +2150,7 @@ static const uint16_t m_9e0e78711a7f87a9[] = {1, 0}; static const uint16_t i_9e0e78711a7f87a9[] = {0, 1}; const ::capnp::_::RawSchema s_9e0e78711a7f87a9 = { 0x9e0e78711a7f87a9, b_9e0e78711a7f87a9.words, 47, d_9e0e78711a7f87a9, m_9e0e78711a7f87a9, - 2, 2, i_9e0e78711a7f87a9, nullptr, nullptr, { &s_9e0e78711a7f87a9, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_9e0e78711a7f87a9, nullptr, nullptr, { &s_9e0e78711a7f87a9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_ac3a6f60ef4cc6d3 = { @@ -2212,7 +2212,7 @@ static const uint16_t m_ac3a6f60ef4cc6d3[] = {1, 0}; static const uint16_t i_ac3a6f60ef4cc6d3[] = {0, 1}; const ::capnp::_::RawSchema s_ac3a6f60ef4cc6d3 = { 0xac3a6f60ef4cc6d3, b_ac3a6f60ef4cc6d3.words, 47, d_ac3a6f60ef4cc6d3, m_ac3a6f60ef4cc6d3, - 2, 2, i_ac3a6f60ef4cc6d3, nullptr, nullptr, { &s_ac3a6f60ef4cc6d3, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_ac3a6f60ef4cc6d3, nullptr, nullptr, { &s_ac3a6f60ef4cc6d3, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_ed8bca69f7fb0cbf = { @@ -2275,7 +2275,7 @@ static const uint16_t m_ed8bca69f7fb0cbf[] = {1, 0}; static const uint16_t i_ed8bca69f7fb0cbf[] = {0, 1}; const ::capnp::_::RawSchema s_ed8bca69f7fb0cbf = { 0xed8bca69f7fb0cbf, b_ed8bca69f7fb0cbf.words, 48, d_ed8bca69f7fb0cbf, m_ed8bca69f7fb0cbf, - 2, 2, i_ed8bca69f7fb0cbf, nullptr, nullptr, { &s_ed8bca69f7fb0cbf, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_ed8bca69f7fb0cbf, nullptr, nullptr, { &s_ed8bca69f7fb0cbf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<46> b_c2573fe8a23e49f1 = { @@ -2338,7 +2338,7 @@ static const uint16_t m_c2573fe8a23e49f1[] = {2, 1, 0}; static const uint16_t i_c2573fe8a23e49f1[] = {0, 1, 2}; const ::capnp::_::RawSchema s_c2573fe8a23e49f1 = { 0xc2573fe8a23e49f1, b_c2573fe8a23e49f1.words, 46, d_c2573fe8a23e49f1, m_c2573fe8a23e49f1, - 4, 3, i_c2573fe8a23e49f1, nullptr, nullptr, { &s_c2573fe8a23e49f1, nullptr, nullptr, 0, 0, nullptr } + 4, 3, i_c2573fe8a23e49f1, nullptr, nullptr, { &s_c2573fe8a23e49f1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<81> b_8e3b5f79fe593656 = { @@ -2433,7 +2433,7 @@ static const uint16_t m_8e3b5f79fe593656[] = {0, 3, 2, 1}; static const uint16_t i_8e3b5f79fe593656[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_8e3b5f79fe593656 = { 0x8e3b5f79fe593656, b_8e3b5f79fe593656.words, 81, d_8e3b5f79fe593656, m_8e3b5f79fe593656, - 1, 4, i_8e3b5f79fe593656, nullptr, nullptr, { &s_8e3b5f79fe593656, nullptr, nullptr, 0, 0, nullptr } + 1, 4, i_8e3b5f79fe593656, nullptr, nullptr, { &s_8e3b5f79fe593656, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_9dd1f724f4614a85 = { @@ -2497,7 +2497,7 @@ static const uint16_t m_9dd1f724f4614a85[] = {1, 0}; static const uint16_t i_9dd1f724f4614a85[] = {0, 1}; const ::capnp::_::RawSchema s_9dd1f724f4614a85 = { 0x9dd1f724f4614a85, b_9dd1f724f4614a85.words, 50, d_9dd1f724f4614a85, m_9dd1f724f4614a85, - 1, 2, i_9dd1f724f4614a85, nullptr, nullptr, { &s_9dd1f724f4614a85, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_9dd1f724f4614a85, nullptr, nullptr, { &s_9dd1f724f4614a85, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_baefc9120c56e274 = { @@ -2548,7 +2548,7 @@ static const uint16_t m_baefc9120c56e274[] = {0}; static const uint16_t i_baefc9120c56e274[] = {0}; const ::capnp::_::RawSchema s_baefc9120c56e274 = { 0xbaefc9120c56e274, b_baefc9120c56e274.words, 37, d_baefc9120c56e274, m_baefc9120c56e274, - 1, 1, i_baefc9120c56e274, nullptr, nullptr, { &s_baefc9120c56e274, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_baefc9120c56e274, nullptr, nullptr, { &s_baefc9120c56e274, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<43> b_903455f06065422b = { @@ -2605,7 +2605,7 @@ static const uint16_t m_903455f06065422b[] = {0}; static const uint16_t i_903455f06065422b[] = {0}; const ::capnp::_::RawSchema s_903455f06065422b = { 0x903455f06065422b, b_903455f06065422b.words, 43, d_903455f06065422b, m_903455f06065422b, - 1, 1, i_903455f06065422b, nullptr, nullptr, { &s_903455f06065422b, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_903455f06065422b, nullptr, nullptr, { &s_903455f06065422b, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<67> b_abd73485a9636bc9 = { @@ -2686,7 +2686,7 @@ static const uint16_t m_abd73485a9636bc9[] = {1, 2, 0}; static const uint16_t i_abd73485a9636bc9[] = {1, 2, 0}; const ::capnp::_::RawSchema s_abd73485a9636bc9 = { 0xabd73485a9636bc9, b_abd73485a9636bc9.words, 67, d_abd73485a9636bc9, m_abd73485a9636bc9, - 1, 3, i_abd73485a9636bc9, nullptr, nullptr, { &s_abd73485a9636bc9, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_abd73485a9636bc9, nullptr, nullptr, { &s_abd73485a9636bc9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_c863cd16969ee7fc = { @@ -2749,7 +2749,7 @@ static const uint16_t m_c863cd16969ee7fc[] = {1, 0}; static const uint16_t i_c863cd16969ee7fc[] = {0, 1}; const ::capnp::_::RawSchema s_c863cd16969ee7fc = { 0xc863cd16969ee7fc, b_c863cd16969ee7fc.words, 49, d_c863cd16969ee7fc, m_c863cd16969ee7fc, - 1, 2, i_c863cd16969ee7fc, nullptr, nullptr, { &s_c863cd16969ee7fc, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_c863cd16969ee7fc, nullptr, nullptr, { &s_c863cd16969ee7fc, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<305> b_ce23dcd2d7b00c9b = { @@ -3065,7 +3065,7 @@ static const uint16_t m_ce23dcd2d7b00c9b[] = {18, 1, 13, 15, 10, 11, 3, 4, 5, 2, static const uint16_t i_ce23dcd2d7b00c9b[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; const ::capnp::_::RawSchema s_ce23dcd2d7b00c9b = { 0xce23dcd2d7b00c9b, b_ce23dcd2d7b00c9b.words, 305, nullptr, m_ce23dcd2d7b00c9b, - 0, 19, i_ce23dcd2d7b00c9b, nullptr, nullptr, { &s_ce23dcd2d7b00c9b, nullptr, nullptr, 0, 0, nullptr } + 0, 19, i_ce23dcd2d7b00c9b, nullptr, nullptr, { &s_ce23dcd2d7b00c9b, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<63> b_f1c8950dab257542 = { @@ -3143,7 +3143,7 @@ static const uint16_t m_f1c8950dab257542[] = {2, 0, 1}; static const uint16_t i_f1c8950dab257542[] = {0, 1, 2}; const ::capnp::_::RawSchema s_f1c8950dab257542 = { 0xf1c8950dab257542, b_f1c8950dab257542.words, 63, d_f1c8950dab257542, m_f1c8950dab257542, - 2, 3, i_f1c8950dab257542, nullptr, nullptr, { &s_f1c8950dab257542, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_f1c8950dab257542, nullptr, nullptr, { &s_f1c8950dab257542, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<54> b_d1958f7dba521926 = { @@ -3207,7 +3207,7 @@ static const ::capnp::_::AlignedData<54> b_d1958f7dba521926 = { static const uint16_t m_d1958f7dba521926[] = {1, 2, 5, 0, 4, 7, 6, 3}; const ::capnp::_::RawSchema s_d1958f7dba521926 = { 0xd1958f7dba521926, b_d1958f7dba521926.words, 54, nullptr, m_d1958f7dba521926, - 0, 8, nullptr, nullptr, nullptr, { &s_d1958f7dba521926, nullptr, nullptr, 0, 0, nullptr } + 0, 8, nullptr, nullptr, nullptr, { &s_d1958f7dba521926, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(ElementSize_d1958f7dba521926, d1958f7dba521926); @@ -3282,7 +3282,7 @@ static const uint16_t m_d85d305b7d839963[] = {0, 2, 1}; static const uint16_t i_d85d305b7d839963[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d85d305b7d839963 = { 0xd85d305b7d839963, b_d85d305b7d839963.words, 63, nullptr, m_d85d305b7d839963, - 0, 3, i_d85d305b7d839963, nullptr, nullptr, { &s_d85d305b7d839963, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d85d305b7d839963, nullptr, nullptr, { &s_d85d305b7d839963, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<98> b_bfc546f6210ad7ce = { @@ -3397,7 +3397,7 @@ static const uint16_t m_bfc546f6210ad7ce[] = {2, 0, 1, 3}; static const uint16_t i_bfc546f6210ad7ce[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_bfc546f6210ad7ce = { 0xbfc546f6210ad7ce, b_bfc546f6210ad7ce.words, 98, d_bfc546f6210ad7ce, m_bfc546f6210ad7ce, - 4, 4, i_bfc546f6210ad7ce, nullptr, nullptr, { &s_bfc546f6210ad7ce, nullptr, nullptr, 0, 0, nullptr } + 4, 4, i_bfc546f6210ad7ce, nullptr, nullptr, { &s_bfc546f6210ad7ce, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<74> b_cfea0eb02e810062 = { @@ -3485,7 +3485,7 @@ static const uint16_t m_cfea0eb02e810062[] = {1, 0, 2}; static const uint16_t i_cfea0eb02e810062[] = {0, 1, 2}; const ::capnp::_::RawSchema s_cfea0eb02e810062 = { 0xcfea0eb02e810062, b_cfea0eb02e810062.words, 74, d_cfea0eb02e810062, m_cfea0eb02e810062, - 1, 3, i_cfea0eb02e810062, nullptr, nullptr, { &s_cfea0eb02e810062, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_cfea0eb02e810062, nullptr, nullptr, { &s_cfea0eb02e810062, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<52> b_ae504193122357e5 = { @@ -3548,7 +3548,7 @@ static const uint16_t m_ae504193122357e5[] = {0, 1}; static const uint16_t i_ae504193122357e5[] = {0, 1}; const ::capnp::_::RawSchema s_ae504193122357e5 = { 0xae504193122357e5, b_ae504193122357e5.words, 52, nullptr, m_ae504193122357e5, - 0, 2, i_ae504193122357e5, nullptr, nullptr, { &s_ae504193122357e5, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_ae504193122357e5, nullptr, nullptr, { &s_ae504193122357e5, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -3560,286 +3560,426 @@ namespace capnp { namespace schema { // Node +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::_capnpPrivate::dataWordSize; constexpr uint16_t Node::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Parameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Parameter::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Parameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Parameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Parameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::NestedNode +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::NestedNode::_capnpPrivate::dataWordSize; constexpr uint16_t Node::NestedNode::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::NestedNode::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::NestedNode::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::SourceInfo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::SourceInfo::_capnpPrivate::dataWordSize; constexpr uint16_t Node::SourceInfo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::SourceInfo::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::SourceInfo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::SourceInfo::Member +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::SourceInfo::Member::_capnpPrivate::dataWordSize; constexpr uint16_t Node::SourceInfo::Member::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::SourceInfo::Member::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::SourceInfo::Member::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Struct +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Struct::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Struct::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Struct::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Struct::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Enum +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Enum::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Enum::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Enum::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Enum::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Const +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Const::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Const::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Const::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Const::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::_capnpPrivate::dataWordSize; constexpr uint16_t Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE -#if !defined(_MSC_VER) || defined(__clang__) +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::uint16_t Field::NO_DISCRIMINANT; #endif // Field::Slot +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Slot::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Slot::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Slot::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Slot::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field::Group +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Group::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Group::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Group::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Group::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field::Ordinal +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Ordinal::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Ordinal::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Ordinal::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Ordinal::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Enumerant +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Enumerant::_capnpPrivate::dataWordSize; constexpr uint16_t Enumerant::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Enumerant::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Enumerant::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Superclass +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Superclass::_capnpPrivate::dataWordSize; constexpr uint16_t Superclass::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Superclass::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Superclass::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Method +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Method::_capnpPrivate::dataWordSize; constexpr uint16_t Method::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Method::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Method::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::_capnpPrivate::dataWordSize; constexpr uint16_t Type::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::List +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::List::_capnpPrivate::dataWordSize; constexpr uint16_t Type::List::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::List::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::List::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Enum +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Enum::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Enum::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Enum::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Enum::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Struct +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Struct::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Struct::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Struct::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Struct::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::Unconstrained +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::Unconstrained::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::Unconstrained::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::Unconstrained::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::Unconstrained::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::Parameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::Parameter::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::Parameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::Parameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::Parameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::ImplicitMethodParameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand::Scope +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::Scope::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::Scope::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::Scope::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::Scope::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand::Binding +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::Binding::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::Binding::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::Binding::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::Binding::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Value::_capnpPrivate::dataWordSize; constexpr uint16_t Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Value::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CapnpVersion +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CapnpVersion::_capnpPrivate::dataWordSize; constexpr uint16_t CapnpVersion::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CapnpVersion::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CapnpVersion::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest::RequestedFile +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::RequestedFile::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::RequestedFile::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::RequestedFile::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::RequestedFile::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest::RequestedFile::Import +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.h index 114d83af05f..9bdc7e70c8e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.h b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.h index 5cc20b5e2ea..5eebacba28c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/schema.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/schema.h @@ -126,6 +126,10 @@ class Schema { BrandArgumentList getBrandArgumentsAtScope(uint64_t scopeId) const; // Gets the values bound to the brand parameters at the given scope. + kj::Array getGenericScopeIds() const; + // Returns the type IDs of all parent scopes that have generic parameters, to which this type is + // subject. + StructSchema asStruct() const; EnumSchema asEnum() const; InterfaceSchema asInterface() const; @@ -275,6 +279,25 @@ class StructSchema: public Schema { bool isStreamResult() const; // Convenience method to check if this is the result type of a streaming RPC method. + bool mayContainCapabilities() const { return raw->generic->mayContainCapabilities; } + // Returns true if a struct of this type may transitively contain any capabilities. I.e., are + // any of the fields an interface type, or a struct type that may in turn contain capabilities? + // + // This is meant for optimizations where various bookkeeping can possibly be skipped if it is + // known in advance that there are no capabilities. Note that this may conservatively return true + // spuriously, e.g. if it would be inconvenient to compute the correct answer. A false positive + // should never cause incorrect behavior, just potentially hurt performance. + // + // It's important to keep in mind that even if a schema has no capability-typed fields today, + // they could always be added in future versions of the schema. So, just because the schema + // doesn't contain capabilities does NOT necessarily mean that an instance of the struct can't + // contain capabilities. However, it is a pretty good hint that the application won't plan to + // use such capabilities -- for example, if there are no caps in an RPC call's response type + // according to the client's version of the schema, then the client clearly isn't going to try + // to make any pipelined calls. The server could be operating with a new version of the schema + // and could actually return capabilities, but for the client to make a pipelined call, the + // client would have to know in advance that capabilities could be returned. + private: StructSchema(Schema base): Schema(base) {} template static inline StructSchema fromImpl() { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async-test.c++ index 380b142a067..dcae6b060fb 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async-test.c++ @@ -35,6 +35,7 @@ #include #include "test-util.h" #include +#include #if _WIN32 #include @@ -382,6 +383,183 @@ TEST(SerializeAsyncTest, WriteMultipleMessagesAsync) { writeMessages(*output, msgs).wait(ioContext.waitScope); } +void writeSmallMessage(kj::OutputStream& output, kj::StringPtr text) { + capnp::MallocMessageBuilder message; + message.getRoot().getAnyPointerField().setAs(text); + writeMessage(output, message); +} + +void expectSmallMessage(MessageStream& stream, kj::StringPtr text, kj::WaitScope& waitScope) { + auto msg = stream.readMessage().wait(waitScope); + KJ_EXPECT(msg->getRoot().getAnyPointerField().getAs() == text); +} + +void writeBigMessage(kj::OutputStream& output) { + capnp::MallocMessageBuilder message(4); // first segment is small + initTestMessage(message.getRoot()); + writeMessage(output, message); +} + +void expectBigMessage(MessageStream& stream, kj::WaitScope& waitScope) { + auto msg = stream.readMessage().wait(waitScope); + checkTestMessage(msg->getRoot()); +} + +KJ_TEST("BufferedMessageStream basics") { + // Encode input data. + kj::VectorOutputStream data; + + writeSmallMessage(data, "foo"); + + KJ_EXPECT(data.getArray().size() / sizeof(word) == 4); + + // A big message (more than half a buffer) + writeBigMessage(data); + + KJ_EXPECT(data.getArray().size() / sizeof(word) > 16); + + writeSmallMessage(data, "bar"); + writeSmallMessage(data, "baz"); + writeSmallMessage(data, "qux"); + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto writePromise = pipe.ends[1]->write(data.getArray().begin(), data.getArray().size()); + + uint callbackCallCount = 0; + auto callback = [&](MessageReader& reader) { + ++callbackCallCount; + return false; + }; + + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + expectSmallMessage(stream, "foo", waitScope); + KJ_EXPECT(callbackCallCount == 1); + + KJ_EXPECT(!writePromise.poll(waitScope)); + + expectBigMessage(stream, waitScope); + KJ_EXPECT(callbackCallCount == 1); // no callback on big message + + KJ_EXPECT(!writePromise.poll(waitScope)); + + expectSmallMessage(stream, "bar", waitScope); + KJ_EXPECT(callbackCallCount == 2); + + // All data is now in the buffer, so this part is done. + KJ_EXPECT(writePromise.poll(waitScope)); + + expectSmallMessage(stream, "baz", waitScope); + expectSmallMessage(stream, "qux", waitScope); + KJ_EXPECT(callbackCallCount == 4); + + auto eofPromise = stream.MessageStream::tryReadMessage(); + KJ_EXPECT(!eofPromise.poll(waitScope)); + + pipe.ends[1]->shutdownWrite(); + KJ_EXPECT(eofPromise.wait(waitScope) == nullptr); +} + +KJ_TEST("BufferedMessageStream fragmented reads") { + // Encode input data. + kj::VectorOutputStream data; + writeBigMessage(data); + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto callback = [&](MessageReader& reader) { + return false; + }; + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + + // Arrange to read a big message. + auto readPromise = stream.MessageStream::tryReadMessage(); + KJ_EXPECT(!readPromise.poll(waitScope)); + + auto remainingData = data.getArray(); + + // Write 5 bytes. This won't even fulfill the first read's minBytes. + pipe.ends[1]->write(remainingData.begin(), 5).wait(waitScope); + remainingData = remainingData.slice(5, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Write 4 more. Now the MessageStream will only see the first word which contains the first + // segment size. This size is small so the MessageStream won't yet fall back to + // readEntireMessage(). + pipe.ends[1]->write(remainingData.begin(), 4).wait(waitScope); + remainingData = remainingData.slice(4, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Drip 10 more bytes. Now the MessageStream will realize that it needs to try + // readEntireMessage(). + pipe.ends[1]->write(remainingData.begin(), 10).wait(waitScope); + remainingData = remainingData.slice(10, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Give it all except the last byte. + pipe.ends[1]->write(remainingData.begin(), remainingData.size() - 1).wait(waitScope); + remainingData = remainingData.slice(remainingData.size() - 1, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Finish it off. + pipe.ends[1]->write(remainingData.begin(), 1).wait(waitScope); + KJ_ASSERT(readPromise.poll(waitScope)); + + auto msg = readPromise.wait(waitScope); + checkTestMessage(KJ_ASSERT_NONNULL(msg)->getRoot()); +} + +KJ_TEST("BufferedMessageStream many small messages") { + // Encode input data. + kj::VectorOutputStream data; + + for (auto i: kj::zeroTo(16)) { + // Intentionally make these 5 words each so they cross buffer boundaries. + writeSmallMessage(data, kj::str("12345678-", i)); + KJ_EXPECT(data.getArray().size() / sizeof(word) == (i+1) * 5); + } + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto writePromise = pipe.ends[1]->write(data.getArray().begin(), data.getArray().size()) + .then([&]() { + // Write some garbage at the end. + return pipe.ends[1]->write("bogus", 5); + }).then([&]() { + // EOF. + return pipe.ends[1]->shutdownWrite(); + }).eagerlyEvaluate(nullptr); + + uint callbackCallCount = 0; + auto callback = [&](MessageReader& reader) { + ++callbackCallCount; + return false; + }; + + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + + for (auto i: kj::zeroTo(16)) { + // Intentionally make these 5 words each so they cross buffer boundaries. + expectSmallMessage(stream, kj::str("12345678-", i), waitScope); + KJ_EXPECT(callbackCallCount == i + 1); + } + + KJ_EXPECT_THROW(DISCONNECTED, stream.MessageStream::tryReadMessage().wait(waitScope)); + KJ_EXPECT(callbackCallCount == 16); +} + +// TODO(test): We should probably test BufferedMessageStream's FD handling here... but really it +// gets tested well enough by rpc-twoparty-test. + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.c++ index 53b30e9ac0f..45eb1846aec 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.c++ @@ -31,6 +31,7 @@ #endif #include "serialize-async.h" +#include "serialize.h" #include #include @@ -295,7 +296,7 @@ kj::Promise writeMessageImpl(kj::ArrayPtr> auto promise = writeFunc(arrays.pieces); // Make sure the arrays aren't freed until the write completes. - return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); + return promise.then([arrays=kj::mv(arrays)]() {}); } template @@ -315,7 +316,7 @@ kj::Promise writeMessagesImpl( size_t tableValsWritten = 0; size_t piecesWritten = 0; - for (int i = 0; i < messages.size(); ++i) { + for (auto i : kj::indices(messages)) { const size_t tableValsToWrite = tableSizeForSegments(messages[i].size()); const size_t piecesToWrite = messages[i].size() + 1; fillWriteArraysWithMessage( @@ -360,15 +361,54 @@ kj::Promise writeMessages( kj::Promise writeMessages( kj::AsyncOutputStream& output, kj::ArrayPtr builders) { auto messages = kj::heapArray>>(builders.size()); - for (int i = 0; i < builders.size(); ++i) { + for (auto i : kj::indices(builders)) { messages[i] = builders[i]->getSegmentsForOutput(); } return writeMessages(output, messages); } +kj::Promise MessageStream::writeMessages(kj::ArrayPtr messages) { + if (messages.size() == 0) return kj::READY_NOW; + kj::ArrayPtr remainingMessages; + + auto writeProm = [&]() { + if (messages[0].fds.size() > 0) { + // We have a message with FDs attached. We need to write any bare messages we've accumulated, + // if any, then write the message with FDs, then continue on with any remaining messages. + + if (messages.size() > 1) { + remainingMessages = messages.slice(1, messages.size()); + } + + return writeMessage(messages[0].fds, messages[0].segments); + } else { + kj::Vector>> bareMessages(messages.size()); + for(auto i : kj::zeroTo(messages.size())) { + if (messages[i].fds.size() > 0) { + break; + } + bareMessages.add(messages[i].segments); + } + + if (messages.size() > bareMessages.size()) { + remainingMessages = messages.slice(bareMessages.size(), messages.size()); + } + return writeMessages(bareMessages.asPtr()).attach(kj::mv(bareMessages)); + } + }(); + + if (remainingMessages.size() > 0) { + return writeProm.then([this, remainingMessages]() mutable { + return writeMessages(remainingMessages); + }); + } else { + return writeProm; + } +} + kj::Promise MessageStream::writeMessages(kj::ArrayPtr builders) { auto messages = kj::heapArray>>(builders.size()); - for (int i = 0; i < builders.size(); ++i) { + for (auto i : kj::indices(builders)) { messages[i] = builders[i]->getSegmentsForOutput(); } return writeMessages(messages); @@ -502,4 +542,301 @@ kj::Promise MessageStream::readMessage( }); } +// ======================================================================================= + +class BufferedMessageStream::MessageReaderImpl: public FlatArrayMessageReader { +public: + MessageReaderImpl(BufferedMessageStream& parent, kj::ArrayPtr data, + ReaderOptions options) + : FlatArrayMessageReader(data, options), state(&parent) { + KJ_DASSERT(!parent.hasOutstandingShortLivedMessage); + parent.hasOutstandingShortLivedMessage = true; + } + MessageReaderImpl(kj::Array&& ownBuffer, ReaderOptions options) + : FlatArrayMessageReader(ownBuffer, options), state(kj::mv(ownBuffer)) {} + MessageReaderImpl(kj::ArrayPtr scratchBuffer, ReaderOptions options) + : FlatArrayMessageReader(scratchBuffer, options) {} + + ~MessageReaderImpl() noexcept(false) { + KJ_IF_MAYBE(parent, state.tryGet()) { + (*parent)->hasOutstandingShortLivedMessage = false; + } + } + +private: + kj::OneOf> state; + // * BufferedMessageStream* if this reader aliases the original buffer. + // * kj::Array if this reader owns its own backing buffer. +}; + +BufferedMessageStream::BufferedMessageStream( + kj::AsyncIoStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords) + : stream(stream), isShortLivedCallback(kj::mv(isShortLivedCallback)), + buffer(kj::heapArray(bufferSizeInWords)), + beginData(buffer.begin()), beginAvailable(buffer.asBytes().begin()) {} + +BufferedMessageStream::BufferedMessageStream( + kj::AsyncCapabilityStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords) + : stream(stream), capStream(stream), isShortLivedCallback(kj::mv(isShortLivedCallback)), + buffer(kj::heapArray(bufferSizeInWords)), + beginData(buffer.begin()), beginAvailable(buffer.asBytes().begin()) {} + +kj::Promise> BufferedMessageStream::tryReadMessage( + kj::ArrayPtr fdSpace, ReaderOptions options, kj::ArrayPtr scratchSpace) { + return tryReadMessageImpl(fdSpace, 0, options, scratchSpace); +} + +kj::Promise BufferedMessageStream::writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + KJ_IF_MAYBE(cs, capStream) { + return capnp::writeMessage(*cs, fds, segments); + } else { + return capnp::writeMessage(stream, segments); + } +} + +kj::Promise BufferedMessageStream::writeMessages( + kj::ArrayPtr>> messages) { + return capnp::writeMessages(stream, messages); +} + +kj::Maybe BufferedMessageStream::getSendBufferSize() { + return capnp::getSendBufferSize(stream); +} + +kj::Promise BufferedMessageStream::end() { + stream.shutdownWrite(); + return kj::READY_NOW; +} + +kj::Promise> BufferedMessageStream::tryReadMessageImpl( + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + KJ_REQUIRE(!hasOutstandingShortLivedMessage, + "can't read another message while the previous short-lived message still exists"); + + kj::byte* beginDataBytes = reinterpret_cast(beginData); + size_t dataByteSize = beginAvailable - beginDataBytes; + kj::ArrayPtr data = kj::arrayPtr(beginData, dataByteSize / sizeof(word)); + + size_t expected = expectedSizeInWordsFromPrefix(data); + + if (!leftoverFds.empty() && expected * sizeof(word) == dataByteSize) { + // We're about to return a message that consumes the rest of the data in the buffer, and + // `leftoverFds` is non-empty. Those FDs are considered attached to whatever message contains + // the last byte in the buffer. That's us! Let's consume them. + + // `fdsSoFar` must be empty here because we shouldn't have performed any reads while + // `leftoverFds` was non-empty, so there shouldn't have been any other chance to add FDs to + // `fdSpace`. + KJ_ASSERT(fdsSoFar == 0); + + fdsSoFar = kj::min(leftoverFds.size(), fdSpace.size()); + for (auto i: kj::zeroTo(fdsSoFar)) { + fdSpace[i] = kj::mv(leftoverFds[i]); + } + leftoverFds.clear(); + } + + if (expected <= data.size()) { + // The buffer contains at least one whole message, which we can just return without reading + // any more data. + + auto msgData = kj::arrayPtr(beginData, expected); + auto reader = kj::heap(*this, msgData, options); + if (!isShortLivedCallback(*reader)) { + // This message is long-lived, so we must make a copy to get it out of our buffer. + if (msgData.size() <= scratchSpace.size()) { + // Oh hey, we can use the provided scratch space. + memcpy(scratchSpace.begin(), msgData.begin(), msgData.asBytes().size()); + reader = kj::heap(scratchSpace, options); + } else { + auto ownMsgData = kj::heapArray(msgData.size()); + memcpy(ownMsgData.begin(), msgData.begin(), msgData.asBytes().size()); + reader = kj::heap(kj::mv(ownMsgData), options); + } + } + + beginData += expected; + if (reinterpret_cast(beginData) == beginAvailable) { + // The buffer is empty. Let's opportunistically reset the pointers. + beginData = buffer.begin(); + beginAvailable = buffer.asBytes().begin(); + } else if (fdsSoFar > 0) { + // The buffer is NOT empty, and we received FDs when we were filling it. These FDs must + // actually belong to the last message in the buffer, because when the OS returns FDs + // attached to a read, it will make sure the read does not extend past the last byte to + // which those FDs were attached. + // + // So, we must set these FDs aside for the moment. + for (auto i: kj::zeroTo(fdsSoFar)) { + leftoverFds.add(kj::mv(fdSpace[i])); + } + fdsSoFar = 0; + } + + return kj::Maybe(MessageReaderAndFds { + kj::mv(reader), + fdSpace.slice(0, fdsSoFar) + }); + } + + // At this point, the buffer doesn't contain a complete message. We are going to need to perform + // a read. + + if (expected > buffer.size() / 2 || fdsSoFar > 0) { + // Read this message into its own separately-allocated buffer. We do this for: + // - Big messages, because they might not fit in the buffer and because big messages are + // almost certainly going to be long-lived and so would require a copy later anyway. + // - Messages where we've already received some FDs, because these are also almost certainly + // long-lived, and we want to avoid accidentally reading into the next message since we + // could end up receiving FDs that were intended for that one. + // + // Optimization note: You might argue that if the expected size is more than half the buffer, + // but still less than the *whole* buffer, then we should still try to read into the buffer + // first. However, keep in mind that in the RPC system, all short-lived messages are + // relatively small, and hence we can assume that since this is a large message, it will + // end up being long-lived. Long-lived messages need to be copied out into their own buffer + // at some point anyway. So we might as well go ahead and allocate that separate buffer + // now, and read directly into it, rather than try to use the shared buffer. We choose to + // use buffer.size() / 2 as the cutoff because that ensures that we won't try to move the + // bytes of a known-large message to the beginning of the buffer (see next if() after this + // one). + + auto prefix = kj::arrayPtr(beginDataBytes, dataByteSize); + + // We are consuming everything in the buffer here, so we can reset the pointers so the + // buffer appears empty on the next message read after this. + beginData = buffer.begin(); + beginAvailable = buffer.asBytes().begin(); + + return readEntireMessage(prefix, expected, fdSpace, fdsSoFar, options); + } + + // Set minBytes to at least complete the current message. + size_t minBytes = expected * sizeof(word) - dataByteSize; + + // minBytes must be less than half the buffer otherwise we would have taken the + // readEntireMessage() branch above. + KJ_DASSERT(minBytes <= buffer.asBytes().size() / 2); + + // Set maxBytes to the space we have available in the buffer. + size_t maxBytes = buffer.asBytes().end() - beginAvailable; + + if (maxBytes < buffer.asBytes().size() / 2) { + // We have less than half the buffer remaining to read into. Move the buffered data to the + // beginning of the buffer to make more space. + memmove(buffer.begin(), beginData, dataByteSize); + beginData = buffer.begin(); + beginDataBytes = buffer.asBytes().begin(); + beginAvailable = beginDataBytes + dataByteSize; + + maxBytes = buffer.asBytes().end() - beginAvailable; + } + + // maxBytes must now be more than half the buffer, because if it weren't we would have moved + // the existing data above, and the existing data cannot be more than half the buffer because + // if it were we would have taken the readEntireMesage() path earlier. + KJ_DASSERT(maxBytes >= buffer.asBytes().size() / 2); + + // Since minBytes is less that half the buffer and maxBytes is more then half, minBytes is + // definitely less than maxBytes. + KJ_DASSERT(minBytes <= maxBytes); + + // Read from underlying stream. + return tryReadWithFds(beginAvailable, minBytes, maxBytes, + fdSpace.begin() + fdsSoFar, fdSpace.size() - fdsSoFar) + .then([this,minBytes,fdSpace,fdsSoFar,options,scratchSpace] + (kj::AsyncCapabilityStream::ReadResult result) mutable + -> kj::Promise> { + // Account for new data received in the buffer. + beginAvailable += result.byteCount; + + if (result.byteCount < minBytes) { + // Didn't reach minBytes, so we must have hit EOF. That's legal as long as it happened on + // a clean message boundray. + if (beginAvailable > reinterpret_cast(beginData)) { + // We had received a partial message before EOF, so this should be considered an error. + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, + "stream disconnected prematurely")); + } + return kj::Maybe(nullptr); + } + + // Loop! + return tryReadMessageImpl(fdSpace, fdsSoFar + result.capCount, options, scratchSpace); + }); +} + +kj::Promise> BufferedMessageStream::readEntireMessage( + kj::ArrayPtr prefix, size_t expectedSizeInWords, + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options) { + KJ_REQUIRE(expectedSizeInWords <= options.traversalLimitInWords, + "incoming RPC message exceeds size limit"); + + auto msgBuffer = kj::heapArray(expectedSizeInWords); + + memcpy(msgBuffer.asBytes().begin(), prefix.begin(), prefix.size()); + + size_t bytesRemaining = msgBuffer.asBytes().size() - prefix.size(); + + // TODO(perf): If we had scatter-read API support, we could optimistically try to read additional + // bytes into the shared buffer, to save syscalls when a big message is immediately followed + // by small messages. + auto promise = tryReadWithFds( + msgBuffer.asBytes().begin() + prefix.size(), bytesRemaining, bytesRemaining, + fdSpace.begin() + fdsSoFar, fdSpace.size() - fdsSoFar); + return promise + .then([this, msgBuffer = kj::mv(msgBuffer), fdSpace, fdsSoFar, options, bytesRemaining] + (kj::AsyncCapabilityStream::ReadResult result) mutable + -> kj::Promise> { + fdsSoFar += result.capCount; + + if (result.byteCount < bytesRemaining) { + // Received EOF during message. + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "stream disconnected prematurely")); + return kj::Maybe(nullptr); + } + + size_t newExpectedSize = expectedSizeInWordsFromPrefix(msgBuffer); + if (newExpectedSize > msgBuffer.size()) { + // Unfortunately, the predicted size increased. This can happen if the segment table had + // not been fully received when we generated the first prediction. This should be rare + // (most segment tables are small and should be received all at once), but in this case we + // will need to make a whole new copy of the message. + // + // We recurse here, but this should never recurse more than once, since we should always + // have the entire segment table by this point and therefore the expected size is now final. + // + // TODO(perf): Technically it's guaranteed that the original expectation should have stopped + // at the boundary between two segments, so with a clever MesnsageReader implementation + // we could actually read the rest of the message into a second buffer, avoiding the copy. + // Unclear if it's worth the effort to implement this. + return readEntireMessage(msgBuffer.asBytes(), newExpectedSize, fdSpace, fdsSoFar, options); + } + + return kj::Maybe(MessageReaderAndFds { + kj::heap(kj::mv(msgBuffer), options), + fdSpace.slice(0, fdsSoFar) + }); + }); +} + +kj::Promise BufferedMessageStream::tryReadWithFds( + void* buffer, size_t minBytes, size_t maxBytes, kj::AutoCloseFd* fdBuffer, size_t maxFds) { + KJ_IF_MAYBE(cs, capStream) { + return cs->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds); + } else { + // Regular byte stream, no FDs. + return stream.tryRead(buffer, minBytes, maxBytes) + .then([](size_t amount) mutable -> kj::AsyncCapabilityStream::ReadResult { + return { amount, 0 }; + }); + } +} + } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.h b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.h index 9bee9ce4c3c..cd661d78091 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-async.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include "message.h" CAPNP_BEGIN_HEADER @@ -33,6 +34,11 @@ struct MessageReaderAndFds { kj::ArrayPtr fds; }; +struct MessageAndFds { + kj::ArrayPtr> segments; + kj::ArrayPtr fds; +}; + class MessageStream { // Interface over which messages can be sent and received; virtualizes // the functionality above. @@ -76,6 +82,9 @@ class MessageStream { KJ_WARN_UNUSED_RESULT; // Equivalent to the above with fds = nullptr. + kj::Promise writeMessages( + kj::ArrayPtr messages) + KJ_WARN_UNUSED_RESULT; virtual kj::Promise writeMessages( kj::ArrayPtr>> messages) KJ_WARN_UNUSED_RESULT = 0; @@ -115,6 +124,10 @@ class AsyncIoMessageStream final: public MessageStream { kj::Maybe getSendBufferSize() override; kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; private: kj::AsyncIoStream& stream; }; @@ -135,10 +148,92 @@ class AsyncCapabilityMessageStream final: public MessageStream { kj::ArrayPtr>> messages) override; kj::Maybe getSendBufferSize() override; kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; private: kj::AsyncCapabilityStream& stream; }; +class BufferedMessageStream final: public MessageStream { + // A MessageStream that reads into a buffer in the hopes of receiving multiple messages in a + // single system call. Compared to the other implementations, this implementation is expected + // to be faster when reading from an OS stream (but probably not when reading from an in-memory + // async pipe). It has the down sides of using more memory (for the buffer) and requiring extra + // copies. + +public: + using IsShortLivedCallback = kj::Function; + // Callback function which decides whether a message will be "short-lived", meaning that it is + // guaranteed to be dropped before the next message is read. The stream uses this as an + // optimization to decide whether it can return a MessageReader pointing into the buffer, which + // will be reused for future reads. For long-lived messages, the stream must copy the content + // into a separate buffer. + + explicit BufferedMessageStream( + kj::AsyncIoStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords = 8192); + explicit BufferedMessageStream( + kj::AsyncCapabilityStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords = 8192); + + // Implements MessageStream + kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) override; + kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) override; + kj::Promise writeMessages( + kj::ArrayPtr>> messages) override; + kj::Maybe getSendBufferSize() override; + kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; + +private: + kj::AsyncIoStream& stream; + kj::Maybe capStream; + IsShortLivedCallback isShortLivedCallback; + + kj::Array buffer; + + word* beginData; + // Pointer to location in `buffer` where the next message starts. This is always on a word + // boundray since messages are always a whole number of words. + + kj::byte* beginAvailable; + // Pointer to the location in `buffer` where unused buffer space begins, i.e. immediately after + // the last byte read. + + kj::Vector leftoverFds; + // FDs which were accidentally read too early. These are always connected to the last message + // in the buffer, since the OS would not have allowed us to read past that point. + + bool hasOutstandingShortLivedMessage = false; + + kj::Promise> tryReadMessageImpl( + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options, kj::ArrayPtr scratchSpace); + + kj::Promise> readEntireMessage( + kj::ArrayPtr prefix, size_t expectedSizeInWords, + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options); + // Given a message prefix and expected size of the whole message, read the entire message into + // a single array and return it. + + kj::Promise tryReadWithFds( + void* buffer, size_t minBytes, size_t maxBytes, kj::AutoCloseFd* fdBuffer, size_t maxFds); + // Executes AsyncCapabilityStream::tryReadWithFds() on the underlying stream, or falls back to + // AsyncIoStream::tryRead() if it's not a capability stream. + + class MessageReaderImpl; +}; + // ----------------------------------------------------------------------------- // Stand-alone functions for reading & writing messages on AsyncInput/AsyncOutputStreams. // diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-packed.h b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-packed.h index 99131f4e8fb..a0329b1300e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-packed.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-packed.h @@ -35,7 +35,7 @@ class PackedInputStream: public kj::InputStream { public: explicit PackedInputStream(kj::BufferedInputStream& inner); - KJ_DISALLOW_COPY(PackedInputStream); + KJ_DISALLOW_COPY_AND_MOVE(PackedInputStream); ~PackedInputStream() noexcept(false); // implements InputStream ------------------------------------------ @@ -50,7 +50,7 @@ class PackedOutputStream: public kj::OutputStream { // An output stream that packs data. Buffers passed to `write()` must be word-aligned. public: explicit PackedOutputStream(kj::BufferedOutputStream& inner); - KJ_DISALLOW_COPY(PackedOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(PackedOutputStream); ~PackedOutputStream() noexcept(false); // implements OutputStream ----------------------------------------- @@ -66,7 +66,7 @@ class PackedMessageReader: private _::PackedInputStream, public InputStreamMessa public: PackedMessageReader(kj::BufferedInputStream& inputStream, ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); - KJ_DISALLOW_COPY(PackedMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(PackedMessageReader); ~PackedMessageReader() noexcept(false); }; @@ -83,7 +83,7 @@ class PackedFdMessageReader: private kj::FdInputStream, private kj::BufferedInpu kj::ArrayPtr scratchSpace = nullptr); // Read a message from a file descriptor, taking ownership of the descriptor. - KJ_DISALLOW_COPY(PackedFdMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(PackedFdMessageReader); ~PackedFdMessageReader() noexcept(false); }; diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-text-test.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-text-test.c++ index c92838c2ffd..8ac42858446 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-text-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize-text-test.c++ @@ -142,6 +142,50 @@ KJ_TEST("TextCodec parse error") { exception.getDescription()); } +KJ_TEST("text format implicitly coerces struct value from first field type") { + // We don't actually use TextCodec here, but rather check how the compiler handled some constants + // defined in test.capnp. It's the same parser code either way but this is easier. + + { + auto s = test::TestImpliedFirstField::Reader().getTextStruct(); + KJ_EXPECT(s.getText() == "foo"); + KJ_EXPECT(s.getI() == 321); + } + + { + auto s = test::TEST_IMPLIED_FIRST_FIELD->getTextStruct(); + KJ_EXPECT(s.getText() == "bar"); + KJ_EXPECT(s.getI() == 321); + } + +#if __GNUC__ && !__clang__ +// GCC generates a spurious warning here... +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#endif + + { + auto l = test::TEST_IMPLIED_FIRST_FIELD->getTextStructList(); + KJ_ASSERT(l.size() == 2); + + { + auto s = l[0]; + KJ_EXPECT(s.getText() == "baz"); + KJ_EXPECT(s.getI() == 321); + } + { + auto s = l[1]; + KJ_EXPECT(s.getText() == "qux"); + KJ_EXPECT(s.getI() == 123); + } + } + + { + auto s = test::TEST_IMPLIED_FIRST_FIELD->getIntGroup(); + KJ_EXPECT(s.getI() == 123); + KJ_EXPECT(s.getStr() == "corge"); + } +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize.c++ index df7e45e0304..abb34f7998b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/serialize.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/serialize.c++ @@ -23,6 +23,10 @@ #include "layout.h" #include #include +#ifdef _WIN32 +#include +#include +#endif namespace capnp { @@ -301,6 +305,15 @@ void writeMessage(kj::OutputStream& output, kj::ArrayPtr> segments) { +#ifdef _WIN32 + auto oldMode = _setmode(fd, _O_BINARY); + if (oldMode != _O_BINARY) { + _setmode(fd, oldMode); + KJ_FAIL_REQUIRE("Tried to write a message to a file descriptor that is in text mode. Set the " + "file descriptor to binary mode by calling the _setmode Windows CRT function, or passing " + "_O_BINARY to _open()."); + } +#endif kj::FdOutputStream stream(fd); writeMessage(stream, segments); } diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.c++ index a5937b2557d..098f26a5b8e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.c++ @@ -28,7 +28,7 @@ static const ::capnp::_::AlignedData<17> b_995f9a3377c0b16e = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_995f9a3377c0b16e = { 0x995f9a3377c0b16e, b_995f9a3377c0b16e.words, 17, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_995f9a3377c0b16e, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_995f9a3377c0b16e, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -39,11 +39,15 @@ const ::capnp::_::RawSchema s_995f9a3377c0b16e = { namespace capnp { // StreamResult +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t StreamResult::_capnpPrivate::dataWordSize; constexpr uint16_t StreamResult::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind StreamResult::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* StreamResult::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.h b/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.h index 4ac41404d4b..91ebcb315d2 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/stream.capnp.h @@ -6,7 +6,9 @@ #include #include -#if CAPNP_VERSION != 9001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1000002 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.c++ b/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.c++ index d16d955e950..41fc08e072d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.c++ @@ -1076,7 +1076,6 @@ kj::Promise TestMoreStuffImpl::neverReturn(NeverReturnContext context) { // Also attach `cap` to the result struct to make sure that is released. context.getResults().setCapCopy(context.getParams().getCap()); - context.allowCancellation(); return kj::mv(promise); } @@ -1119,7 +1118,6 @@ kj::Promise TestMoreStuffImpl::echo(EchoContext context) { kj::Promise TestMoreStuffImpl::expectCancel(ExpectCancelContext context) { auto cap = context.getParams().getCap(); - context.allowCancellation(); return loop(0, cap, context); } @@ -1129,7 +1127,7 @@ kj::Promise TestMoreStuffImpl::loop(uint depth, test::TestInterface::Clien ADD_FAILURE() << "Looped too long, giving up."; return kj::READY_NOW; } else { - return kj::evalLater([this,depth,KJ_CPCAP(cap),KJ_CPCAP(context)]() mutable { + return kj::evalLast([this,depth,KJ_CPCAP(cap),KJ_CPCAP(context)]() mutable { return loop(depth + 1, cap, context); }); } @@ -1190,6 +1188,10 @@ kj::Promise TestMoreStuffImpl::throwException(ThrowExceptionContext contex return KJ_EXCEPTION(FAILED, "test exception"); } +kj::Promise TestMoreStuffImpl::throwRemoteException(ThrowRemoteExceptionContext context) { + return KJ_EXCEPTION(FAILED, "remote exception: test exception"); +} + #endif // !CAPNP_LITE } // namespace _ (private) diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.h b/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.h index 2cf47cb8465..38b445718a0 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.h +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/test-util.h @@ -280,6 +280,8 @@ class TestMoreStuffImpl final: public test::TestMoreStuff::Server { kj::Promise throwException(ThrowExceptionContext context) override; + kj::Promise throwRemoteException(ThrowRemoteExceptionContext context) override; + private: int& callCount; int& handleCount; @@ -336,7 +338,6 @@ class TestStreamingImpl final: public test::TestStreaming::Server { } kj::Promise doStreamJ(DoStreamJContext context) override { - context.allowCancellation(); jSum += context.getParams().getJ(); if (jShouldThrow) { diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/test.capnp b/libs/EXTERNAL/capnproto/c++/src/capnp/test.capnp index 7efe6aee447..11b52f6c312 100644 --- a/libs/EXTERNAL/capnproto/c++/src/capnp/test.capnp +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/test.capnp @@ -653,6 +653,8 @@ struct TestUseGenerics $TestGenerics(Text, Data).ann("foo") { inner2Bind = (baz = "text", innerBound = (foo = (int16Field = 123))), inner2Text = (baz = "text", innerBound = (foo = (int16Field = 123))), revFoo = [12, 34, 56]); + + bindEnumList @20 :TestGenerics(List(TestEnum), Text); } struct TestEmptyStruct {} @@ -758,6 +760,13 @@ const embeddedStruct :TestAllTypes = embed "testdata/binary"; const nonAsciiText :Text = "♫ é ✓"; +const blockText :Text = + `foo bar baz + `"qux" `corge` 'grault' + "regular\"quoted\"line" + `garply\nwaldo\tfred\"plugh\"xyzzy\'thud + ; + struct TestAnyPointerConstants { anyKindAsStruct @0 :AnyPointer; anyStructAsStruct @1 :AnyStruct; @@ -774,7 +783,7 @@ const anyPointerConstants :TestAnyPointerConstants = ( struct TestListOfAny { capList @0 :List(Capability); - #listList @1 :List(AnyList); # TODO(0.10): Make List(AnyList) work correctly in C++ generated code. + #listList @1 :List(AnyList); # TODO(someday): Make List(AnyList) work correctly in C++ generated code. } interface TestInterface { @@ -814,7 +823,7 @@ interface TestCallOrder { # The input `expected` is ignored but useful for disambiguating debug logs. } -interface TestTailCallee { +interface TestTailCallee $Cxx.allowCancellation { struct TailResult { i @0 :UInt32; t @1 :Text; @@ -828,7 +837,7 @@ interface TestTailCaller { foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult; } -interface TestStreaming { +interface TestStreaming $Cxx.allowCancellation { doStreamI @0 (i :UInt32) -> stream; doStreamJ @1 (j :UInt32) -> stream; finishStream @2 () -> (totalI :UInt32, totalJ :UInt32); @@ -846,7 +855,7 @@ interface TestMoreStuff extends(TestCallOrder) { callFooWhenResolved @1 (cap :TestInterface) -> (s: Text); # Like callFoo but waits for `cap` to resolve first. - neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface); + neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface) $Cxx.allowCancellation; # Doesn't return. You should cancel it. hold @3 (cap :TestInterface) -> (); @@ -861,7 +870,7 @@ interface TestMoreStuff extends(TestCallOrder) { echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder); # Just returns the input cap. - expectCancel @7 (cap :TestInterface) -> (); + expectCancel @7 (cap :TestInterface) -> () $Cxx.allowCancellation; # evalLater()-loops forever, holding `cap`. Must be canceled. methodWithDefaults @8 (a :Text, b :UInt32 = 123, c :Text = "foo") -> (d :Text, e :Text = "bar"); @@ -884,6 +893,7 @@ interface TestMoreStuff extends(TestCallOrder) { # the second. Also creates a socketpair, writes "baz" to one end, and returns the other end. throwException @14 (); + throwRemoteException @15 (); } interface TestMembrane { @@ -892,7 +902,7 @@ interface TestMembrane { callIntercept @2 (thing :Thing, tailCall :Bool) -> Result; loopback @3 (thing :Thing) -> (thing :Thing); - waitForever @4 (); + waitForever @4 () $Cxx.allowCancellation; interface Thing { passThrough @0 () -> Result; @@ -991,3 +1001,42 @@ struct TestNameAnnotation $Cxx.name("RenamedStruct") { interface TestNameAnnotationInterface $Cxx.name("RenamedInterface") { badlyNamedMethod @0 (badlyNamedParam :UInt8 $Cxx.name("renamedParam")) $Cxx.name("renamedMethod"); } + +struct TestImpliedFirstField { + struct TextStruct { + text @0 :Text; + i @1 :UInt32 = 321; + } + + textStruct @0 :TextStruct = "foo"; + textStructList @1 :List(TextStruct); + + intGroup :group { + i @2 :UInt32; + str @3 :Text = "corge"; + } +} + +const testImpliedFirstField :TestImpliedFirstField = ( + textStruct = "bar", + textStructList = ["baz", (text = "qux", i = 123)], + intGroup = 123 +); + +struct TestCycleANoCaps { + foo @0 :TestCycleBNoCaps; +} + +struct TestCycleBNoCaps { + foo @0 :List(TestCycleANoCaps); + bar @1 :TestAllTypes; +} + +struct TestCycleAWithCaps { + foo @0 :TestCycleBWithCaps; +} + +struct TestCycleBWithCaps { + foo @0 :List(TestCycleAWithCaps); + bar @1 :TestInterface; +} diff --git a/libs/EXTERNAL/capnproto/c++/src/capnp/testdata/no-file-id.capnp.nobuild b/libs/EXTERNAL/capnproto/c++/src/capnp/testdata/no-file-id.capnp.nobuild new file mode 100644 index 00000000000..98c92e29245 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/capnp/testdata/no-file-id.capnp.nobuild @@ -0,0 +1 @@ +const foo :Text = "bar"; diff --git a/libs/EXTERNAL/capnproto/c++/src/ekam-rules b/libs/EXTERNAL/capnproto/c++/src/ekam-rules new file mode 120000 index 00000000000..ff5b3b4f985 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/ekam-rules @@ -0,0 +1 @@ +../deps/ekam/src/ekam/rules \ No newline at end of file diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/BUILD.bazel b/libs/EXTERNAL/capnproto/c++/src/kj/BUILD.bazel new file mode 100644 index 00000000000..f5527bea95c --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/BUILD.bazel @@ -0,0 +1,261 @@ +load("//:build/configure.bzl", "kj_configure") + +kj_configure() + +cc_library( + name = "kj", + srcs = [ + "arena.c++", + "array.c++", + "cidr.c++", + "common.c++", + "debug.c++", + "encoding.c++", + "exception.c++", + "filesystem.c++", + "filesystem-disk-unix.c++", + "filesystem-disk-win32.c++", + "hash.c++", + "io.c++", + "list.c++", + "main.c++", + "memory.c++", + "mutex.c++", + "parse/char.c++", + "refcount.c++", + "source-location.c++", + "string.c++", + "string-tree.c++", + "table.c++", + "test-helpers.c++", + "thread.c++", + "time.c++", + "units.c++", + ], + hdrs = [ + "arena.h", + "array.h", + "cidr.h", + "common.h", + "debug.h", + "encoding.h", + "exception.h", + "filesystem.h", + "function.h", + "hash.h", + "io.h", + "list.h", + "main.h", + "map.h", + "memory.h", + "miniposix.h", + "mutex.h", + "one-of.h", + "parse/char.h", + "parse/common.h", + "refcount.h", + "source-location.h", + "std/iostream.h", + "string.h", + "string-tree.h", + "table.h", + "test.h", + "thread.h", + "threadlocal.h", + "time.h", + "tuple.h", + "units.h", + "vector.h", + "win32-api-version.h", + "windows-sanity.h", + ], + include_prefix = "kj", + linkopts = select({ + "@platforms//os:windows": [], + ":use_libdl": [ + "-lpthread", + "-ldl", + ], + "//conditions:default": ["-lpthread"], + }), + visibility = ["//visibility:public"], + deps = [":kj-defines"], +) + +cc_library( + name = "kj-async", + srcs = [ + "async.c++", + "async-io.c++", + "async-io-unix.c++", + "async-io-win32.c++", + "async-unix.c++", + "async-win32.c++", + "timer.c++", + ], + hdrs = [ + "async.h", + "async-inl.h", + "async-io.h", + "async-io-internal.h", + "async-prelude.h", + "async-queue.h", + "async-unix.h", + "async-win32.h", + "timer.h", + ], + include_prefix = "kj", + linkopts = select({ + "@platforms//os:windows": [ + "Ws2_32.lib", + "Advapi32.lib", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [":kj"], +) + +cc_library( + name = "kj-test", + srcs = [ + "test.c++", + ], + include_prefix = "kj", + visibility = ["//visibility:public"], + deps = [ + ":kj", + "//src/kj/compat:gtest", + ], +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":kj", + ":kj-async", + ":kj-test", + ], +) for f in [ + "arena-test.c++", + "array-test.c++", + "async-io-test.c++", + "async-queue-test.c++", + "async-test.c++", + "async-xthread-test.c++", + "common-test.c++", + "debug-test.c++", + "encoding-test.c++", + "exception-test.c++", + "filesystem-disk-test.c++", + "filesystem-test.c++", + "function-test.c++", + "io-test.c++", + "list-test.c++", + "map-test.c++", + "memory-test.c++", + "mutex-test.c++", + "one-of-test.c++", + "parse/char-test.c++", + "refcount-test.c++", + "std/iostream-test.c++", + "string-test.c++", + "string-tree-test.c++", + "table-test.c++", + "test-test.c++", + "threadlocal-test.c++", + "thread-test.c++", + "time-test.c++", + "tuple-test.c++", + "units-test.c++", +]] + +cc_test( + name = "async-coroutine-test", + srcs = ["async-coroutine-test.c++"], + target_compatible_with = select({ + ":use_coroutines": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-test", + "//src/kj/compat:kj-http", + ], +) + +cc_library( + name = "filesystem-disk-test-base", + hdrs = [ + "filesystem-disk-test.c++", + "filesystem-disk-unix.c++", + ], +) + +cc_test( + name = "filesystem-disk-generic-test", + srcs = ["filesystem-disk-generic-test.c++"], + deps = [ + ":filesystem-disk-test-base", + ":kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", + ], +) + +cc_test( + name = "filesystem-disk-old-kernel-test", + srcs = ["filesystem-disk-old-kernel-test.c++"], + deps = [ + ":filesystem-disk-test-base", + ":kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", + ], +) + +cc_test( + name = "async-os-test", + srcs = select({ + "@platforms//os:windows": ["async-win32-test.c++"], + "//conditions:default": ["async-unix-test.c++"], + }), + deps = [ + ":kj", + ":kj-async", + ":kj-test", + ], +) + +cc_library( + name = "async-os-xthread-test-base", + hdrs = ["async-xthread-test.c++"], +) + +cc_test( + name = "async-os-xthread-test", + srcs = select({ + "@platforms//os:windows": ["async-win32-xthread-test.c++"], + "//conditions:default": ["async-unix-xthread-test.c++"], + }), + deps = [ + ":async-os-xthread-test-base", + ":kj-async", + ":kj-test", + ], +) + +cc_test( + name = "exception-override-symbolizer-test", + srcs = ["exception-override-symbolizer-test.c++"], + deps = [ + ":kj", + ":kj-test", + ], + linkstatic = True, + target_compatible_with = [ + "@platforms//os:linux", + ], +) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/CMakeLists.txt b/libs/EXTERNAL/capnproto/c++/src/kj/CMakeLists.txt index 813fac4deed..980c53e34c8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/CMakeLists.txt +++ b/libs/EXTERNAL/capnproto/c++/src/kj/CMakeLists.txt @@ -3,6 +3,7 @@ set(kj_sources_lite array.c++ + cidr.c++ list.c++ common.c++ debug.c++ @@ -37,6 +38,7 @@ else() endif() set(kj_headers + cidr.h common.h units.h memory.h @@ -64,6 +66,7 @@ set(kj_headers filesystem.h time.h main.h + win32-api-version.h windows-sanity.h ) set(kj-parse_headers @@ -84,8 +87,9 @@ endif() #make sure the lite flag propagates to all users (internal + external) of this library target_compile_definitions(kj PUBLIC ${CAPNP_LITE_FLAG}) #make sure external consumers don't need to manually set the include dirs -target_include_directories(kj INTERFACE - $ +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +target_include_directories(kj PUBLIC + $ $ ) # Ensure the library has a version set to match autotools build @@ -130,12 +134,22 @@ set(kj-async_headers async-win32.h async-io.h async-queue.h + cidr.h timer.h ) if(NOT CAPNP_LITE) add_library(kj-async ${kj-async_sources}) add_library(CapnProto::kj-async ALIAS kj-async) target_link_libraries(kj-async PUBLIC kj) + if(WITH_FIBERS) + target_compile_definitions(kj-async PUBLIC KJ_USE_FIBERS) + if(_WITH_LIBUCONTEXT) + target_link_libraries(kj-async PUBLIC PkgConfig::libucontext) + endif() + else() + target_compile_definitions(kj-async PUBLIC KJ_USE_FIBERS=0) + endif() + if(UNIX) # external clients of this library need to link to pthreads target_compile_options(kj-async INTERFACE "-pthread") @@ -161,7 +175,12 @@ set(kj-http_headers if(NOT CAPNP_LITE) add_library(kj-http ${kj-http_sources}) add_library(CapnProto::kj-http ALIAS kj-http) - target_link_libraries(kj-http PUBLIC kj-async kj) + if(WITH_ZLIB) + target_compile_definitions(kj-http PRIVATE KJ_HAS_ZLIB) + target_link_libraries(kj-http PUBLIC kj-async kj ZLIB::ZLIB) + else() + target_link_libraries(kj-http PUBLIC kj-async kj) + endif() # Ensure the library has a version set to match autotools build set_target_properties(kj-http PROPERTIES VERSION ${VERSION}) install(TARGETS kj-http ${INSTALL_TARGETS_DEFAULT_ARGS}) @@ -169,51 +188,51 @@ if(NOT CAPNP_LITE) endif() # kj-tls ====================================================================== -set(kj-tls_sources - compat/readiness-io.c++ - compat/tls.c++ -) -set(kj-tls_headers - compat/readiness-io.h - compat/tls.h -) -if(NOT CAPNP_LITE) - add_library(kj-tls ${kj-tls_sources}) - add_library(CapnProto::kj-tls ALIAS kj-tls) - target_link_libraries(kj-tls PUBLIC kj-async) - if (WITH_OPENSSL) +if(WITH_OPENSSL) + set(kj-tls_sources + compat/readiness-io.c++ + compat/tls.c++ + ) + set(kj-tls_headers + compat/readiness-io.h + compat/tls.h + ) + if(NOT CAPNP_LITE) + add_library(kj-tls ${kj-tls_sources}) + add_library(CapnProto::kj-tls ALIAS kj-tls) + target_link_libraries(kj-tls PUBLIC kj-async) + target_compile_definitions(kj-tls PRIVATE KJ_HAS_OPENSSL) target_link_libraries(kj-tls PRIVATE OpenSSL::SSL OpenSSL::Crypto) + + # Ensure the library has a version set to match autotools build + set_target_properties(kj-tls PROPERTIES VERSION ${VERSION}) + install(TARGETS kj-tls ${INSTALL_TARGETS_DEFAULT_ARGS}) + install(FILES ${kj-tls_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") endif() - # Ensure the library has a version set to match autotools build - set_target_properties(kj-tls PROPERTIES VERSION ${VERSION}) - install(TARGETS kj-tls ${INSTALL_TARGETS_DEFAULT_ARGS}) - install(FILES ${kj-tls_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") endif() # kj-gzip ====================================================================== -set(kj-gzip_sources - compat/gzip.c++ -) -set(kj-gzip_headers - compat/gzip.h -) -if(NOT CAPNP_LITE) - add_library(kj-gzip ${kj-gzip_sources}) - add_library(CapnProto::kj-gzip ALIAS kj-gzip) +if(WITH_ZLIB) + set(kj-gzip_sources + compat/gzip.c++ + ) + set(kj-gzip_headers + compat/gzip.h + ) + if(NOT CAPNP_LITE) + add_library(kj-gzip ${kj-gzip_sources}) + add_library(CapnProto::kj-gzip ALIAS kj-gzip) - find_package(ZLIB) - if(ZLIB_FOUND) - add_definitions(-D KJ_HAS_ZLIB=1) - include_directories(${ZLIB_INCLUDE_DIRS}) - target_link_libraries(kj-gzip PUBLIC kj-async kj ${ZLIB_LIBRARIES}) - endif() + target_compile_definitions(kj-gzip PRIVATE KJ_HAS_ZLIB) + target_link_libraries(kj-gzip PUBLIC kj-async kj ZLIB::ZLIB) - # Ensure the library has a version set to match autotools build - set_target_properties(kj-gzip PROPERTIES VERSION ${VERSION}) - install(TARGETS kj-gzip ${INSTALL_TARGETS_DEFAULT_ARGS}) - install(FILES ${kj-gzip_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") + # Ensure the library has a version set to match autotools build + set_target_properties(kj-gzip PROPERTIES VERSION ${VERSION}) + install(TARGETS kj-gzip ${INSTALL_TARGETS_DEFAULT_ARGS}) + install(FILES ${kj-gzip_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") + endif() endif() # Tests ======================================================================== @@ -228,6 +247,8 @@ if(BUILD_TESTING) table-test.c++ map-test.c++ exception-test.c++ + # this test overrides symbolizer and has to be linked separately + # exception-override-symbolizer-test.c++ debug-test.c++ io-test.c++ mutex-test.c++ @@ -245,6 +266,7 @@ if(BUILD_TESTING) add_executable(kj-heavy-tests async-test.c++ async-xthread-test.c++ + async-coroutine-test.c++ async-unix-test.c++ async-unix-xthread-test.c++ async-win32-test.c++ @@ -268,11 +290,22 @@ if(BUILD_TESTING) compat/gzip-test.c++ compat/tls-test.c++ ) - target_link_libraries(kj-heavy-tests kj-http kj-gzip kj-tls kj-async kj-test kj) - if (WITH_OPENSSL) - set_source_files_properties(compat/tls-test.c++ - PROPERTIES - COMPILE_DEFINITIONS KJ_HAS_OPENSSL + target_link_libraries(kj-heavy-tests kj-http kj-async kj-test kj) + if(WITH_OPENSSL) + target_link_libraries(kj-heavy-tests kj-tls) + # tls-test.c++ needs to use OpenSSL directly. + target_link_libraries(kj-heavy-tests OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(kj-heavy-tests PRIVATE KJ_HAS_OPENSSL) + set_property( + SOURCE compat/tls-test.c++ + APPEND PROPERTY COMPILE_DEFINITIONS KJ_HAS_OPENSSL + ) + endif() + if(WITH_ZLIB) + target_link_libraries(kj-heavy-tests kj-gzip) + set_property( + SOURCE compat/gzip-test.c++ + APPEND PROPERTY COMPILE_DEFINITIONS KJ_HAS_ZLIB ) endif() add_dependencies(check kj-heavy-tests) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/arena.h b/libs/EXTERNAL/capnproto/c++/src/kj/arena.h index 63e0e31ed03..a16b2911216 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/arena.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/arena.h @@ -47,7 +47,7 @@ class Arena { explicit Arena(ArrayPtr scratch); // Allocates from the given scratch space first, only resorting to the heap when it runs out. - KJ_DISALLOW_COPY(Arena); + KJ_DISALLOW_COPY_AND_MOVE(Arena); ~Arena() noexcept(false); template @@ -134,11 +134,11 @@ class Arena { template T& Arena::allocate(Params&&... params) { T& result = *reinterpret_cast(allocateBytes( - sizeof(T), alignof(T), !__has_trivial_destructor(T))); - if (!__has_trivial_constructor(T) || sizeof...(Params) > 0) { + sizeof(T), alignof(T), !KJ_HAS_TRIVIAL_DESTRUCTOR(T))); + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T) || sizeof...(Params) > 0) { ctor(result, kj::fwd(params)...); } - if (!__has_trivial_destructor(T)) { + if (!KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { setDestructor(&result, &destroyObject); } return result; @@ -146,11 +146,11 @@ T& Arena::allocate(Params&&... params) { template ArrayPtr Arena::allocateArray(size_t size) { - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { ArrayPtr result = arrayPtr(reinterpret_cast(allocateBytes( sizeof(T) * size, alignof(T), false)), size); - if (!__has_trivial_constructor(T)) { + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { for (size_t i = 0; i < size; i++) { ctor(result[i]); } @@ -165,7 +165,7 @@ ArrayPtr Arena::allocateArray(size_t size) { arrayPtr(reinterpret_cast(reinterpret_cast(base) + prefixSize), size); setDestructor(base, &destroyArray); - if (__has_trivial_constructor(T)) { + if (KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { tag = size; } else { // In case of constructor exceptions, we need the tag to end up storing the number of objects @@ -183,7 +183,7 @@ ArrayPtr Arena::allocateArray(size_t size) { template Own Arena::allocateOwn(Params&&... params) { T& result = *reinterpret_cast(allocateBytes(sizeof(T), alignof(T), false)); - if (!__has_trivial_constructor(T) || sizeof...(Params) > 0) { + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T) || sizeof...(Params) > 0) { ctor(result, kj::fwd(params)...); } return Own(&result, DestructorOnlyDisposer::instance); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/array-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/array-test.c++ index d361b65cab1..24e9d748dc8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/array-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/array-test.c++ @@ -378,7 +378,7 @@ TEST(Array, ReleaseAsBytesOrChars) { } } -#if __cplusplus > 201402L +#if KJ_CPP_STD > 201402L KJ_TEST("kj::arr()") { kj::Array array = kj::arr(kj::str("foo"), kj::str(123)); KJ_EXPECT(array == kj::ArrayPtr({"foo", "123"})); @@ -386,7 +386,7 @@ KJ_TEST("kj::arr()") { struct ImmovableInt { ImmovableInt(int i): i(i) {} - KJ_DISALLOW_COPY(ImmovableInt); + KJ_DISALLOW_COPY_AND_MOVE(ImmovableInt); int i; }; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/array.h b/libs/EXTERNAL/capnproto/c++/src/kj/array.h index fdd7d4cf504..3932f9f4f3e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/array.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/array.h @@ -56,7 +56,7 @@ class ArrayDisposer { // an exception. private: - template + template struct Dispose_; }; @@ -74,7 +74,7 @@ class ExceptionSafeArrayUtil { : pos(reinterpret_cast(ptr) + elementSize * constructedElementCount), elementSize(elementSize), constructedElementCount(constructedElementCount), destroyElement(destroyElement) {} - KJ_DISALLOW_COPY(ExceptionSafeArrayUtil); + KJ_DISALLOW_COPY_AND_MOVE(ExceptionSafeArrayUtil); inline ~ExceptionSafeArrayUtil() noexcept(false) { if (constructedElementCount > 0) destroyAll(); @@ -285,8 +285,8 @@ class HeapArrayDisposer final: public ArrayDisposer { virtual void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, size_t capacity, void (*destroyElement)(void*)) const override; - template + template struct Allocate_; }; @@ -417,7 +417,7 @@ class ArrayBuilder { KJ_IREQUIRE(size <= this->size(), "can't use truncate() to expand"); T* target = ptr + size; - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { pos = target; } else { while (pos > target) { @@ -427,7 +427,7 @@ class ArrayBuilder { } void clear() { - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { pos = ptr; } else { while (pos > ptr) { @@ -442,7 +442,7 @@ class ArrayBuilder { T* target = ptr + size; if (target > pos) { // expand - if (__has_trivial_constructor(T)) { + if (KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { pos = target; } else { while (pos < target) { @@ -451,7 +451,7 @@ class ArrayBuilder { } } else { // truncate - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { pos = target; } else { while (pos > target) { @@ -848,7 +848,7 @@ inline Array heapArray(std::initializer_list init) { return heapArray(init.begin(), init.end()); } -#if __cplusplus > 201402L +#if KJ_CPP_STD > 201402L template inline Array> arr(T&& param1, Params&&... params) { ArrayBuilder> builder = heapArrayBuilder>(sizeof...(params) + 1); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-coroutine-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-coroutine-test.c++ new file mode 100644 index 00000000000..de767eca0cb --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-coroutine-test.c++ @@ -0,0 +1,578 @@ +// Copyright (c) 2020 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include +#include +#include +#include +#include + +namespace kj { +namespace { + +#ifdef KJ_HAS_COROUTINE + +template +Promise> identity(T&& value) { + co_return kj::fwd(value); +} +// Work around a bonkers MSVC ICE with a separate overload. +Promise identity(const char* value) { + co_return value; +} + +KJ_TEST("Identity coroutine") { + EventLoop loop; + WaitScope waitScope(loop); + + KJ_EXPECT(identity(123).wait(waitScope) == 123); + KJ_EXPECT(*identity(kj::heap(456)).wait(waitScope) == 456); + + { + auto p = identity("we can cancel the coroutine"); + } +} + +template +Promise simpleCoroutine(kj::Promise result, kj::Promise dontThrow = true) { + KJ_ASSERT(co_await dontThrow); + co_return co_await result; +} + +KJ_TEST("Simple coroutine test") { + EventLoop loop; + WaitScope waitScope(loop); + + simpleCoroutine(kj::Promise(kj::READY_NOW)).wait(waitScope); + + KJ_EXPECT(simpleCoroutine(kj::Promise(123)).wait(waitScope) == 123); +} + +struct Counter { + size_t& wind; + size_t& unwind; + Counter(size_t& wind, size_t& unwind): wind(wind), unwind(unwind) { ++wind; } + ~Counter() { ++unwind; } + KJ_DISALLOW_COPY_AND_MOVE(Counter); +}; + +kj::Promise countAroundAwait(size_t& wind, size_t& unwind, kj::Promise promise) { + Counter counter1(wind, unwind); + co_await promise; + Counter counter2(wind, unwind); + co_return; +}; + +KJ_TEST("co_awaiting initial immediate promises suspends even if event loop is empty and running") { + // The coroutine PromiseNode implementation contains an optimization which allows us to avoid + // suspending the coroutine and instead immediately call PromiseNode::get() and proceed with + // execution, but only if the coroutine has suspended at least once. This test verifies that the + // optimization is disabled for this initial suspension. + + EventLoop loop; + WaitScope waitScope(loop); + + // The immediate-execution optimization is only enabled when the event loop is running, so use an + // eagerly-evaluated evalLater() to perform the test from within the event loop. (If we didn't + // eagerly-evaluate the promise, the result would be extracted after the loop finished, which + // would disable the optimization anyway.) + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // `coro` has not completed. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); + + kj::evalLater([&]() { + // If there are no background tasks in the queue, coroutines execute through an evalLater() + // without suspending. + + size_t wind = 0, unwind = 0; + bool evalLaterRan = false; + + auto promise = kj::evalLater([&]() { evalLaterRan = true; }); + auto coroPromise = countAroundAwait(wind, unwind, kj::mv(promise)); + + KJ_EXPECT(evalLaterRan == false); + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); +} + +KJ_TEST("co_awaiting an immediate promise suspends if the event loop is not running") { + // We only want to enable the immediate-execution optimization if the event loop is running, or + // else a whole bunch of RPC tests break, because some .then()s get evaluated on promise + // construction, before any .wait() call. + + EventLoop loop; + WaitScope waitScope(loop); + + size_t wind = 0, unwind = 0; + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // In the previous test, this exact same code executed immediately because the event loop was + // running. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); +} + +KJ_TEST("co_awaiting immediate promises suspends if the event loop is not empty") { + // We want to make sure that we can still return to the event loop when we need to. + + EventLoop loop; + WaitScope waitScope(loop); + + // The immediate-execution optimization is only enabled when the event loop is running, so use an + // eagerly-evaluated evalLater() to perform the test from within the event loop. (If we didn't + // eagerly-evaluate the promise, the result would be extracted after the loop finished.) + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + + // We need to enqueue an Event on the event loop to inhibit the immediate-execution + // optimization. Creating and then immediately fulfilling an EagerPromiseNode is a convenient + // way to do so. + auto paf = newPromiseAndFulfiller(); + paf.promise = paf.promise.eagerlyEvaluate(nullptr); + paf.fulfiller->fulfill(); + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // We didn't immediately extract the READY_NOW. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); + + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + bool evalLaterRan = false; + + // We need to enqueue an Event on the event loop to inhibit the immediate-execution + // optimization. Creating and then immediately fulfilling an EagerPromiseNode is a convenient + // way to do so. + auto paf = newPromiseAndFulfiller(); + paf.promise = paf.promise.eagerlyEvaluate(nullptr); + paf.fulfiller->fulfill(); + + auto promise = kj::evalLater([&]() { evalLaterRan = true; }); + auto coroPromise = countAroundAwait(wind, unwind, kj::mv(promise)); + + // We didn't continue through the evalLater() promise, because the background promise's + // continuation was next in the event loop's queue. + KJ_EXPECT(evalLaterRan == false); + // No Counter destructor has run. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); +} + +KJ_TEST("Exceptions propagate through layered coroutines") { + EventLoop loop; + WaitScope waitScope(loop); + + auto throwy = simpleCoroutine(kj::Promise(kj::NEVER_DONE), false); + + KJ_EXPECT_THROW_RECOVERABLE(FAILED, simpleCoroutine(kj::mv(throwy)).wait(waitScope)); +} + +KJ_TEST("Exceptions before the first co_await don't escape, but reject the promise") { + EventLoop loop; + WaitScope waitScope(loop); + + auto throwEarly = []() -> Promise { + KJ_FAIL_ASSERT("test exception"); +#ifdef __GNUC__ +// Yes, this `co_return` is unreachable. But without it, this function is no longer a coroutine. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunreachable-code" +#endif // __GNUC__ + co_return; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif // __GNUC__ + }; + + auto throwy = throwEarly(); + + KJ_EXPECT_THROW_RECOVERABLE(FAILED, throwy.wait(waitScope)); +} + +KJ_TEST("Coroutines can catch exceptions from co_await") { + EventLoop loop; + WaitScope waitScope(loop); + + kj::String description; + + auto tryCatch = [&](kj::Promise promise) -> kj::Promise { + try { + co_await promise; + } catch (const kj::Exception& exception) { + co_return kj::str(exception.getDescription()); + } + KJ_FAIL_EXPECT("should have thrown"); + KJ_UNREACHABLE; + }; + + { + // Immediately ready case. + auto promise = kj::Promise(KJ_EXCEPTION(FAILED, "catch me")); + KJ_EXPECT(tryCatch(kj::mv(promise)).wait(waitScope) == "catch me"); + } + + { + // Ready later case. + auto promise = kj::evalLater([]() -> kj::Promise { + return KJ_EXCEPTION(FAILED, "catch me"); + }); + KJ_EXPECT(tryCatch(kj::mv(promise)).wait(waitScope) == "catch me"); + } +} + +KJ_TEST("Coroutines can be canceled while suspended") { + EventLoop loop; + WaitScope waitScope(loop); + + size_t wind = 0, unwind = 0; + + auto coro = [&](kj::Promise promise) -> kj::Promise { + Counter counter1(wind, unwind); + co_await kj::evalLater([](){}); + Counter counter2(wind, unwind); + co_await promise; + }; + + { + auto neverDone = kj::Promise(kj::NEVER_DONE); + neverDone = neverDone.attach(kj::heap(wind, unwind)); + auto promise = coro(kj::mv(neverDone)); + KJ_EXPECT(!promise.poll(waitScope)); + } + + // Stack variables on both sides of a co_await, plus coroutine arguments are destroyed. + KJ_EXPECT(wind == 3); + KJ_EXPECT(unwind == 3); +} + +kj::Promise deferredThrowCoroutine(kj::Promise awaitMe) { + KJ_DEFER(kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind"))); + co_await awaitMe; + co_return; +}; + +KJ_TEST("Exceptions during suspended coroutine frame-unwind propagate via destructor") { + EventLoop loop; + WaitScope waitScope(loop); + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + deferredThrowCoroutine(kj::NEVER_DONE); + })); + + KJ_EXPECT(exception.getDescription() == "thrown during unwind"); +}; + +KJ_TEST("Exceptions during suspended coroutine frame-unwind do not cause a memory leak") { + EventLoop loop; + WaitScope waitScope(loop); + + // We can't easily test for memory leaks without hooking operator new and delete. However, we can + // arrange for the test to crash on failure, by having the coroutine suspend at a promise that we + // later fulfill, thus arming the Coroutine's Event. If we fail to destroy the coroutine in this + // state, EventLoop will throw on destruction because it can still see the Event in its list. + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + auto paf = kj::newPromiseAndFulfiller(); + + auto coroPromise = deferredThrowCoroutine(kj::mv(paf.promise)); + + // Arm the Coroutine's Event. + paf.fulfiller->fulfill(); + + // If destroying `coroPromise` does not run ~Event(), then ~EventLoop() will crash later. + })); + + KJ_EXPECT(exception.getDescription() == "thrown during unwind"); +}; + +KJ_TEST("Exceptions during completed coroutine frame-unwind propagate via returned Promise") { + EventLoop loop; + WaitScope waitScope(loop); + + { + // First, prove that exceptions don't escape the destructor of a completed coroutine. + auto promise = deferredThrowCoroutine(kj::READY_NOW); + KJ_EXPECT(promise.poll(waitScope)); + } + + { + // Next, prove that they show up via the returned Promise. + auto promise = deferredThrowCoroutine(kj::READY_NOW); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("thrown during unwind", promise.wait(waitScope)); + } +} + +KJ_TEST("Coroutine destruction exceptions are ignored if there is another exception in flight") { + EventLoop loop; + WaitScope waitScope(loop); + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + auto promise = deferredThrowCoroutine(kj::NEVER_DONE); + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown before destroying throwy promise")); + })); + + KJ_EXPECT(exception.getDescription() == "thrown before destroying throwy promise"); +} + +KJ_TEST("co_await only sees coroutine destruction exceptions if promise was not rejected") { + EventLoop loop; + WaitScope waitScope(loop); + + // throwyDtorPromise is an immediate void promise that will throw when it's destroyed, which + // we expect to be able to catch from a coroutine which co_awaits it. + auto throwyDtorPromise = kj::Promise(kj::READY_NOW) + .attach(kj::defer([]() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind")); + })); + + // rejectedThrowyDtorPromise is a rejected promise. When co_awaited in a coroutine, + // Awaiter::await_resume() will throw that exception for us to catch, but before we can catch it, + // the temporary promise will be destroyed. The exception it throws during unwind will be ignored, + // and the caller of the coroutine will see only the "thrown during execution" exception. + auto rejectedThrowyDtorPromise = kj::evalNow([&]() -> kj::Promise { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during execution")); + }).attach(kj::defer([]() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind")); + })); + + auto awaitPromise = [](kj::Promise promise) -> kj::Promise { + co_await promise; + }; + + KJ_EXPECT_THROW_MESSAGE("thrown during unwind", + awaitPromise(kj::mv(throwyDtorPromise)).wait(waitScope)); + + KJ_EXPECT_THROW_MESSAGE("thrown during execution", + awaitPromise(kj::mv(rejectedThrowyDtorPromise)).wait(waitScope)); +} + +#if !_MSC_VER && !__aarch64__ +uint countLines(StringPtr s) { + uint lines = 0; + for (char c: s) { + lines += c == '\n'; + } + return lines; +} + +// TODO(msvc): This test relies on GetFunctorStartAddress, which is not supported on MSVC currently, +// so skip the test. +// TODO(someday): Test is flakey on arm64, depending on how it's compiled. I haven't had a chance to +// investigate much, but noticed that it failed in a debug build, but passed in a local opt build. +KJ_TEST("Can trace through coroutines") { + // This verifies that async traces, generated either from promises or from events, can see through + // coroutines. + // + // This test may be a bit brittle because it depends on specific trace counts. + + // Enable stack traces, even in release mode. + class EnableFullStackTrace: public ExceptionCallback { + public: + StackTraceMode stackTraceMode() override { return StackTraceMode::FULL; } + }; + EnableFullStackTrace exceptionCallback; + + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + + // Get an async trace when the promise is fulfilled. We eagerlyEvaluate() to make sure the + // continuation executes while the event loop is running. + paf.promise = paf.promise.then([]() { + auto trace = getAsyncTrace(); + // We expect one entry for waitImpl(), one for the coroutine, and one for this continuation. + // When building in debug mode with CMake, I observed this count can be 2. The missing frame is + // probably this continuation. Let's just expect a range. + auto count = countLines(trace); + KJ_EXPECT(0 < count && count <= 3); + }).eagerlyEvaluate(nullptr); + + auto coroPromise = [&]() -> kj::Promise { + co_await paf.promise; + }(); + + { + auto trace = coroPromise.trace(); + // One for the Coroutine PromiseNode, one for paf.promise. + KJ_EXPECT(countLines(trace) >= 2); + } + + paf.fulfiller->fulfill(); + + coroPromise.wait(waitScope); +} +#endif // !_MSC_VER || defined(__clang__) + +Promise sendData(Promise> addressPromise) { + auto address = co_await addressPromise; + auto client = co_await address->connect(); + co_await client->write("foo", 3); +} + +Promise receiveDataCoroutine(Own listener) { + auto server = co_await listener->accept(); + char buffer[4]; + auto n = co_await server->read(buffer, 3, 4); + KJ_EXPECT(3u == n); + co_return heapString(buffer, n); +} + +KJ_TEST("Simple network test with coroutine") { + auto io = setupAsyncIo(); + auto& network = io.provider->getNetwork(); + + Own serverAddress = network.parseAddress("*", 0).wait(io.waitScope); + Own listener = serverAddress->listen(); + + sendData(network.parseAddress("localhost", listener->getPort())) + .detach([](Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + String result = receiveDataCoroutine(kj::mv(listener)).wait(io.waitScope); + + KJ_EXPECT("foo" == result); +} + +Promise> httpClientConnect(AsyncIoContext& io) { + auto addr = co_await io.provider->getNetwork().parseAddress("capnproto.org", 80); + co_return co_await addr->connect(); +} + +Promise httpClient(Own connection) { + // Borrowed and rewritten from compat/http-test.c++. + + HttpHeaderTable table; + auto client = newHttpClient(table, *connection); + + HttpHeaders headers(table); + headers.set(HttpHeaderId::HOST, "capnproto.org"); + + auto response = co_await client->request(HttpMethod::GET, "/", headers).response; + KJ_EXPECT(response.statusCode / 100 == 3); + auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); + KJ_EXPECT(location == "https://capnproto.org/"); + + auto body = co_await response.body->readAllText(); +} + +KJ_TEST("HttpClient to capnproto.org with a coroutine") { + auto io = setupAsyncIo(); + + auto promise = httpClientConnect(io).then([](Own connection) { + return httpClient(kj::mv(connection)); + }, [](Exception&&) { + KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); + }); + + promise.wait(io.waitScope); +} + +// ======================================================================================= +// coCapture() tests + +KJ_TEST("Verify coCapture() functors can only be run once") { + auto io = kj::setupAsyncIo(); + + auto functor = coCapture([](kj::Timer& timer) -> kj::Promise { + co_await timer.afterDelay(1 * kj::MILLISECONDS); + }); + + auto promise = functor(io.lowLevelProvider->getTimer()); + KJ_EXPECT_THROW(FAILED, functor(io.lowLevelProvider->getTimer())); + + promise.wait(io.waitScope); +} + +auto makeDelayedIntegerFunctor(size_t i) { + return [i](kj::Timer& timer) -> kj::Promise { + co_await timer.afterDelay(1 * kj::MILLISECONDS); + co_return i; + }; +} + +KJ_TEST("Verify coCapture() with local scoped functors") { + auto io = kj::setupAsyncIo(); + + constexpr size_t COUNT = 100; + kj::Vector> promises; + for (size_t i = 0; i < COUNT; ++i) { + auto functor = coCapture(makeDelayedIntegerFunctor(i)); + promises.add(functor(io.lowLevelProvider->getTimer())); + } + + for (size_t i = COUNT; i > 0 ; --i) { + auto j = i-1; + auto result = promises[j].wait(io.waitScope); + KJ_REQUIRE(result == j); + } +} + +auto makeCheckThenDelayedIntegerFunctor(kj::Timer& timer, size_t i) { + return [&timer, i](size_t val) -> kj::Promise { + KJ_REQUIRE(val == i); + co_await timer.afterDelay(1 * kj::MILLISECONDS); + co_return i; + }; +} + +KJ_TEST("Verify coCapture() with continuation functors") { + // This test usually works locally without `coCapture()()`. It does however, fail in + // ASAN. + auto io = kj::setupAsyncIo(); + + constexpr size_t COUNT = 100; + kj::Vector> promises; + for (size_t i = 0; i < COUNT; ++i) { + auto promise = io.lowLevelProvider->getTimer().afterDelay(1 * kj::MILLISECONDS).then([i]() { + return i; + }); + promise = promise.then(coCapture( + makeCheckThenDelayedIntegerFunctor(io.lowLevelProvider->getTimer(), i))); + promises.add(kj::mv(promise)); + } + + for (size_t i = COUNT; i > 0 ; --i) { + auto j = i-1; + auto result = promises[j].wait(io.waitScope); + KJ_REQUIRE(result == j); + } +} + +#endif // KJ_HAS_COROUTINE + +} // namespace +} // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-inl.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-inl.h index 55ab97a335a..fc8a9813244 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-inl.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-inl.h @@ -31,9 +31,13 @@ #include "async.h" // help IDE parse this file #endif -KJ_BEGIN_HEADER +#if _MSC_VER && KJ_HAS_COROUTINE +#include +#endif + +#include -#include "list.h" +KJ_BEGIN_HEADER namespace kj { namespace _ { // private @@ -134,15 +138,20 @@ class TraceBuilder { void** limit; }; -class Event { +struct alignas(void*) PromiseArena { + // Space in which a chain of promises may be allocated. See PromiseDisposer. + byte bytes[1024]; +}; + +class Event: private AsyncObject { // An event waiting to be executed. Not for direct use by applications -- promises use this // internally. public: - Event(); - Event(kj::EventLoop& loop); + Event(SourceLocation location); + Event(kj::EventLoop& loop, SourceLocation location); ~Event() noexcept(false); - KJ_DISALLOW_COPY(Event); + KJ_DISALLOW_COPY_AND_MOVE(Event); void armDepthFirst(); // Enqueue this event so that `fire()` will be called from the event loop soon. @@ -166,6 +175,18 @@ class Event { // Enqueues this event to happen after all other events have run to completion and there is // really nothing left to do except wait for I/O. + bool isNext(); + // True if the Event has been armed and is next in line to be fired. This can be used after + // calling PromiseNode::onReady(event) to determine if a promise being waited is immediately + // ready, in which case continuations may be optimistically run without returning to the event + // loop. Note that this optimization is only valid if we know that we would otherwise immediately + // return to the event loop without running more application code. So this turns out to be useful + // in fairly narrow circumstances, chiefly when a coroutine is about to suspend, but discovers it + // doesn't need to. + // + // Returns false if the event loop is not currently running. This ensures that promise + // continuations don't execute except under a call to .wait(). + void disarm(); // If the event is armed but hasn't fired, cancel it. (Destroying the event does this // implicitly.) @@ -194,9 +215,46 @@ class Event { Event* next; Event** prev; bool firing = false; + + static constexpr uint MAGIC_LIVE_VALUE = 0x1e366381u; + uint live = MAGIC_LIVE_VALUE; + SourceLocation location; +}; + +class PromiseArenaMember { + // An object that is allocated in a PromiseArena. `PromiseNode` inherits this, and most + // arena-allocated objects are `PromiseNode` subclasses, but `TaskSet::Task`, ForkHub, and + // potentially other objects that commonly live on the end of a promise chain can also leverage + // this. + +public: + virtual void destroy() = 0; + // Destroys and frees the node. + // + // If the node was allocated using allocPromise(), then destroy() must call + // freePromise(this). If it was allocated some other way, then it is `destroy()`'s + // responsibility to complete any necessary cleanup of memory, e.g. call `delete this`. + // + // We use this instead of a virtual destructor for two reasons: + // 1. Coroutine nodes are not independent objects, they have to call destroy() on the coroutine + // handle to delete themselves. + // 2. XThreadEvents sometimes leave it up to a different thread to actually delete the object. + +private: + PromiseArena* arena = nullptr; + // If non-null, then this PromiseNode is the last node allocated within the given arena, and + // therefore owns the arena. After this node is destroyed, the arena should be deleted. + // + // PromiseNodes are allocated within the arena starting from the end, and `PromiseNode`s + // allocated this way are required to have `PromiseNode` itself as their leftmost inherited type, + // so that the pointers match. Thus, the space in `arena` from its start to the location of the + // `PromiseNode` is known to be available for subsequent allocations (which should then take + // ownership of the arena). + + friend class PromiseDisposer; }; -class PromiseNode { +class PromiseNode: public PromiseArenaMember, private AsyncObject { // A Promise contains a chain of PromiseNodes tracking the pending transformations. // // To reduce generated code bloat, PromiseNode is not a template. Instead, it makes very hacky @@ -212,7 +270,7 @@ class PromiseNode { // never be armed, only the new one. If called again after the event was armed, the new event // will be armed immediately. Can be called with nullptr to un-register the existing event. - virtual void setSelfPointer(Own* selfPtr) noexcept; + virtual void setSelfPointer(OwnPromiseNode* selfPtr) noexcept; // Tells the node that `selfPtr` is the pointer that owns this node, and will continue to own // this node until it is destroyed or setSelfPointer() is called again. ChainPromiseNode uses // this to shorten redundant chains. The default implementation does nothing; only @@ -231,7 +289,7 @@ class PromiseNode { // // If `stopAtNextEvent` is true, then the trace should stop as soon as it hits a PromiseNode that // also implements Event, and should not trace that node or its children. This is used in - // conjuction with Event::traceEvent(). The chain of Events is often more sparse than the chain + // conjunction with Event::traceEvent(). The chain of Events is often more sparse than the chain // of PromiseNodes, because a TransformPromiseNode (which implements .then()) is not itself an // Event. TransformPromiseNode instead tells its child node to directly notify its *parent* node // when it is ready, and then TransformPromiseNode applies the .then() transformation during the @@ -246,7 +304,7 @@ class PromiseNode { // must not allocate nor take locks. template - static Own from(T&& promise) { + static OwnPromiseNode from(T&& promise) { // Given a Promise, extract the PromiseNode. return kj::mv(promise.node); } @@ -256,7 +314,7 @@ class PromiseNode { return *promise.node; } template - static T to(Own&& node) { + static T to(OwnPromiseNode&& node) { // Construct a Promise from a PromiseNode. (T should be a Promise type.) return T(false, kj::mv(node)); } @@ -282,8 +340,126 @@ class PromiseNode { }; }; +class PromiseDisposer { +public: + template + static constexpr bool canArenaAllocate() { + // We can only use arena allocation for types that fit in an arena and have pointer-size + // alignment. Anything else will need to be allocated as a separate heap object. + return sizeof(T) <= sizeof(PromiseArena) && alignof(T) <= alignof(void*); + } + + static void dispose(PromiseArenaMember* node) { + PromiseArena* arena = node->arena; + node->destroy(); + delete arena; // reminder: `delete` automatically ignores null pointers + } + + template + static kj::Own alloc(Params&&... params) noexcept { + // Implements allocPromise(). + T* ptr; + if (!canArenaAllocate()) { + // Node too big (or needs weird alignment), fall back to regular heap allocation. + ptr = new T(kj::fwd(params)...); + } else { + // Start a new arena. + // + // NOTE: As in append() (below), we don't implement exception-safety because it causes code + // bloat and these constructors probably don't throw. Instead this function is noexcept, so + // if a constructor does throw, it'll crash rather than leak memory. + auto* arena = new PromiseArena; + ptr = reinterpret_cast(arena + 1) - 1; + ctor(*ptr, kj::fwd(params)...); + ptr->arena = arena; + KJ_IREQUIRE(reinterpret_cast(ptr) == + reinterpret_cast(static_cast(ptr)), + "PromiseArenaMember must be the leftmost inherited type."); + } + return kj::Own(ptr); + } + + template + static kj::Own append( + OwnPromiseNode&& next, Params&&... params) noexcept { + // Implements appendPromise(). + + PromiseArena* arena = next->arena; + + if (!canArenaAllocate() || arena == nullptr || + reinterpret_cast(next.get()) - reinterpret_cast(arena) < sizeof(T)) { + // No arena available, or not enough space, or weird alignment needed. Start new arena. + return alloc(kj::mv(next), kj::fwd(params)...); + } else { + // Append to arena. + // + // NOTE: When we call ctor(), it takes ownership of `next`, so we shouldn't assume `next` + // still exists after it returns. So we have to remove ownership of the arena before that. + // In theory if we wanted this to be exception-safe, we'd also have to arrange to delete + // the arena if the constructor throws. However, in practice none of the PromiseNode + // constructors throw, so we just mark the whole method noexcept in order to avoid the + // code bloat to handle this case. + next->arena = nullptr; + T* ptr = reinterpret_cast(next.get()) - 1; + ctor(*ptr, kj::mv(next), kj::fwd(params)...); + ptr->arena = arena; + KJ_IREQUIRE(reinterpret_cast(ptr) == + reinterpret_cast(static_cast(ptr)), + "PromiseArenaMember must be the leftmost inherited type."); + return kj::Own(ptr); + } + } +}; + +template +static kj::Own allocPromise(Params&&... params) { + // Allocate a PromiseNode without appending it to any existing promise arena. Space for a new + // arena will be allocated. + return PromiseDisposer::alloc(kj::fwd(params)...); +} + +template ()> +struct FreePromiseNode; +template +struct FreePromiseNode { + static inline void free(T* ptr) { + // The object will have been allocated in an arena, so we only want to run the destructor. + // The arena's memory will be freed separately. + kj::dtor(*ptr); + } +}; +template +struct FreePromiseNode { + static inline void free(T* ptr) { + // The object will have been allocated separately on the heap. + return delete ptr; + } +}; + +template +static void freePromise(T* ptr) { + // Free a PromiseNode originally allocated using `allocPromise()`. The implementation of + // PromiseNode::destroy() must call this for any type that is allocated using allocPromise(). + FreePromiseNode::free(ptr); +} + +template +static kj::Own appendPromise(OwnPromiseNode&& next, Params&&... params) { + // Append a promise to the arena that currently ends with `next`. `next` is also still passed as + // the first parameter to the new object's constructor. + // + // This is semantically the same as `allocPromise()` except that it may avoid the underlying + // memory allocation. `next` must end up being destroyed before the new object (i.e. the new + // object must never transfer away ownership of `next`). + return PromiseDisposer::append(kj::mv(next), kj::fwd(params)...); +} + // ------------------------------------------------------------------- +inline ReadyNow::operator Promise() const { + return PromiseNode::to>(readyNow()); +} + template inline NeverDone::operator Promise() const { return PromiseNode::to>(neverDone()); @@ -306,6 +482,7 @@ class ImmediatePromiseNode final: public ImmediatePromiseNodeBase { public: ImmediatePromiseNode(ExceptionOr&& result): result(kj::mv(result)) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { output.as() = kj::mv(result); @@ -318,6 +495,7 @@ class ImmediatePromiseNode final: public ImmediatePromiseNodeBase { class ImmediateBrokenPromiseNode final: public ImmediatePromiseNodeBase { public: ImmediateBrokenPromiseNode(Exception&& exception); + void destroy() override; void get(ExceptionOrValue& output) noexcept override; @@ -325,18 +503,27 @@ class ImmediateBrokenPromiseNode final: public ImmediatePromiseNodeBase { Exception exception; }; +template +class ConstPromiseNode: public ImmediatePromiseNodeBase { +public: + void destroy() override {} + void get(ExceptionOrValue& output) noexcept override { + output.as() = value; + } +}; + // ------------------------------------------------------------------- class AttachmentPromiseNodeBase: public PromiseNode { public: - AttachmentPromiseNodeBase(Own&& dependency); + AttachmentPromiseNodeBase(OwnPromiseNode&& dependency); void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; void dropDependency(); @@ -350,9 +537,10 @@ class AttachmentPromiseNode final: public AttachmentPromiseNodeBase { // object) until the promise resolves. public: - AttachmentPromiseNode(Own&& dependency, Attachment&& attachment) + AttachmentPromiseNode(OwnPromiseNode&& dependency, Attachment&& attachment) : AttachmentPromiseNodeBase(kj::mv(dependency)), attachment(kj::mv(attachment)) {} + void destroy() override { freePromise(this); } ~AttachmentPromiseNode() noexcept(false) { // We need to make sure the dependency is deleted before we delete the attachment because the @@ -498,14 +686,14 @@ struct GetFunctorStartAddress: public GetFunctorStartAddress<> {}; class TransformPromiseNodeBase: public PromiseNode { public: - TransformPromiseNodeBase(Own&& dependency, void* continuationTracePtr); + TransformPromiseNodeBase(OwnPromiseNode&& dependency, void* continuationTracePtr); void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; void* continuationTracePtr; void dropDependency(); @@ -523,10 +711,11 @@ class TransformPromiseNode final: public TransformPromiseNodeBase { // function (implements `then()`). public: - TransformPromiseNode(Own&& dependency, Func&& func, ErrorFunc&& errorHandler, + TransformPromiseNode(OwnPromiseNode&& dependency, Func&& func, ErrorFunc&& errorHandler, void* continuationTracePtr) : TransformPromiseNodeBase(kj::mv(dependency), continuationTracePtr), func(kj::fwd(func)), errorHandler(kj::fwd(errorHandler)) {} + void destroy() override { freePromise(this); } ~TransformPromiseNode() noexcept(false) { // We need to make sure the dependency is deleted before we delete the continuations because it @@ -562,10 +751,11 @@ class TransformPromiseNode final: public TransformPromiseNodeBase { // ------------------------------------------------------------------- class ForkHubBase; +using OwnForkHubBase = Own; class ForkBranchBase: public PromiseNode { public: - ForkBranchBase(Own&& hub); + ForkBranchBase(OwnForkHubBase&& hub); ~ForkBranchBase() noexcept(false); void hubReady() noexcept; @@ -584,7 +774,7 @@ class ForkBranchBase: public PromiseNode { private: OnReadyEvent onReadyEvent; - Own hub; + OwnForkHubBase hub; ForkBranchBase* next = nullptr; ForkBranchBase** prevPtr = nullptr; @@ -605,7 +795,8 @@ class ForkBranch final: public ForkBranchBase { // a const reference. public: - ForkBranch(Own&& hub): ForkBranchBase(kj::mv(hub)) {} + ForkBranch(OwnForkHubBase&& hub): ForkBranchBase(kj::mv(hub)) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { ExceptionOr& hubResult = getHubResultRef().template as(); @@ -625,7 +816,8 @@ class SplitBranch final: public ForkBranchBase { // a const reference. public: - SplitBranch(Own&& hub): ForkBranchBase(kj::mv(hub)) {} + SplitBranch(OwnForkHubBase&& hub): ForkBranchBase(kj::mv(hub)) {} + void destroy() override { freePromise(this); } typedef kj::Decay(kj::instance()))> Element; @@ -643,14 +835,31 @@ class SplitBranch final: public ForkBranchBase { // ------------------------------------------------------------------- -class ForkHubBase: public Refcounted, protected Event { +class ForkHubBase: public PromiseArenaMember, protected Event { public: - ForkHubBase(Own&& inner, ExceptionOrValue& resultRef); + ForkHubBase(OwnPromiseNode&& inner, ExceptionOrValue& resultRef, SourceLocation location); inline ExceptionOrValue& getResultRef() { return resultRef; } + inline bool isShared() const { return refcount > 1; } + + Own addRef() { + ++refcount; + return Own(this); + } + + static void dispose(ForkHubBase* obj) { + if (--obj->refcount == 0) { + PromiseDisposer::dispose(obj); + } + } + private: - Own inner; + uint refcount = 1; + // We manually implement refcounting for ForkHubBase so that we can use it together with + // PromiseDisposer's arena allocation. + + OwnPromiseNode inner; ExceptionOrValue& resultRef; ForkBranchBase* headBranch = nullptr; @@ -670,29 +879,33 @@ class ForkHub final: public ForkHubBase { // possible). public: - ForkHub(Own&& inner): ForkHubBase(kj::mv(inner), result) {} + ForkHub(OwnPromiseNode&& inner, SourceLocation location) + : ForkHubBase(kj::mv(inner), result, location) {} + void destroy() override { freePromise(this); } Promise<_::UnfixVoid> addBranch() { - return _::PromiseNode::to>>(kj::heap>(addRef(*this))); + return _::PromiseNode::to>>( + allocPromise>(addRef())); } - _::SplitTuplePromise split() { - return splitImpl(MakeIndexes()>()); + _::SplitTuplePromise split(SourceLocation location) { + return splitImpl(MakeIndexes()>(), location); } private: ExceptionOr result; template - _::SplitTuplePromise splitImpl(Indexes) { - return kj::tuple(addSplit()...); + _::SplitTuplePromise splitImpl(Indexes, SourceLocation location) { + return kj::tuple(addSplit(location)...); } template - ReducePromises::Element> addSplit() { + ReducePromises::Element> addSplit(SourceLocation location) { return _::PromiseNode::to::Element>>( - maybeChain(kj::heap>(addRef(*this)), - implicitCast::Element*>(nullptr))); + maybeChain(allocPromise>(addRef()), + implicitCast::Element*>(nullptr), + location)); } }; @@ -709,11 +922,12 @@ class ChainPromiseNode final: public PromiseNode, public Event { // Own. Ugh, templates and private... public: - explicit ChainPromiseNode(Own inner); + explicit ChainPromiseNode(OwnPromiseNode inner, SourceLocation location); ~ChainPromiseNode() noexcept(false); + void destroy() override; void onReady(Event* event) noexcept override; - void setSelfPointer(Own* selfPtr) noexcept override; + void setSelfPointer(OwnPromiseNode* selfPtr) noexcept override; void get(ExceptionOrValue& output) noexcept override; void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; @@ -725,24 +939,24 @@ class ChainPromiseNode final: public PromiseNode, public Event { State state; - Own inner; + OwnPromiseNode inner; // In STEP1, a PromiseNode for a Promise. // In STEP2, a PromiseNode for a T. Event* onReadyEvent = nullptr; - Own* selfPtr = nullptr; + OwnPromiseNode* selfPtr = nullptr; Maybe> fire() override; void traceEvent(TraceBuilder& builder) override; }; template -Own maybeChain(Own&& node, Promise*) { - return heap(kj::mv(node)); +OwnPromiseNode maybeChain(OwnPromiseNode&& node, Promise*, SourceLocation location) { + return appendPromise(kj::mv(node), location); } template -Own&& maybeChain(Own&& node, T*) { +OwnPromiseNode&& maybeChain(OwnPromiseNode&& node, T*, SourceLocation location) { return kj::mv(node); } @@ -760,8 +974,9 @@ inline Promise maybeReduce(Promise&& promise, ...) { class ExclusiveJoinPromiseNode final: public PromiseNode { public: - ExclusiveJoinPromiseNode(Own left, Own right); + ExclusiveJoinPromiseNode(OwnPromiseNode left, OwnPromiseNode right, SourceLocation location); ~ExclusiveJoinPromiseNode() noexcept(false); + void destroy() override; void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; @@ -770,7 +985,8 @@ class ExclusiveJoinPromiseNode final: public PromiseNode { private: class Branch: public Event { public: - Branch(ExclusiveJoinPromiseNode& joinNode, Own dependency); + Branch(ExclusiveJoinPromiseNode& joinNode, OwnPromiseNode dependency, + SourceLocation location); ~Branch() noexcept(false); bool get(ExceptionOrValue& output); @@ -781,7 +997,7 @@ class ExclusiveJoinPromiseNode final: public PromiseNode { private: ExclusiveJoinPromiseNode& joinNode; - Own dependency; + OwnPromiseNode dependency; friend class ExclusiveJoinPromiseNode; }; @@ -793,10 +1009,17 @@ class ExclusiveJoinPromiseNode final: public PromiseNode { // ------------------------------------------------------------------- +enum class ArrayJoinBehavior { + LAZY, + EAGER, +}; + class ArrayJoinPromiseNodeBase: public PromiseNode { public: - ArrayJoinPromiseNodeBase(Array> promises, - ExceptionOrValue* resultParts, size_t partSize); + ArrayJoinPromiseNodeBase(Array promises, + ExceptionOrValue* resultParts, size_t partSize, + SourceLocation location, + ArrayJoinBehavior joinBehavior); ~ArrayJoinPromiseNodeBase() noexcept(false); void onReady(Event* event) noexcept override final; @@ -808,24 +1031,24 @@ class ArrayJoinPromiseNodeBase: public PromiseNode { // Called to compile the result only in the case where there were no errors. private: + const ArrayJoinBehavior joinBehavior; + uint countLeft; OnReadyEvent onReadyEvent; + bool armed = false; class Branch final: public Event { public: - Branch(ArrayJoinPromiseNodeBase& joinNode, Own dependency, - ExceptionOrValue& output); + Branch(ArrayJoinPromiseNodeBase& joinNode, OwnPromiseNode dependency, + ExceptionOrValue& output, SourceLocation location); ~Branch() noexcept(false); Maybe> fire() override; void traceEvent(TraceBuilder& builder) override; - Maybe getPart(); - // Calls dependency->get(output). If there was an exception, return it. - private: ArrayJoinPromiseNodeBase& joinNode; - Own dependency; + OwnPromiseNode dependency; ExceptionOrValue& output; friend class ArrayJoinPromiseNodeBase; @@ -837,10 +1060,14 @@ class ArrayJoinPromiseNodeBase: public PromiseNode { template class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { public: - ArrayJoinPromiseNode(Array> promises, - Array> resultParts) - : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr)), + ArrayJoinPromiseNode(Array promises, + Array> resultParts, + SourceLocation location, + ArrayJoinBehavior joinBehavior) + : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr), + location, joinBehavior), resultParts(kj::mv(resultParts)) {} + void destroy() override { freePromise(this); } protected: void getNoError(ExceptionOrValue& output) noexcept override { @@ -860,9 +1087,12 @@ class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { template <> class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { public: - ArrayJoinPromiseNode(Array> promises, - Array> resultParts); + ArrayJoinPromiseNode(Array promises, + Array> resultParts, + SourceLocation location, + ArrayJoinBehavior joinBehavior); ~ArrayJoinPromiseNode(); + void destroy() override; protected: void getNoError(ExceptionOrValue& output) noexcept override; @@ -878,13 +1108,14 @@ class EagerPromiseNodeBase: public PromiseNode, protected Event { // evaluate it. public: - EagerPromiseNodeBase(Own&& dependency, ExceptionOrValue& resultRef); + EagerPromiseNodeBase(OwnPromiseNode&& dependency, ExceptionOrValue& resultRef, + SourceLocation location); void onReady(Event* event) noexcept override; void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; OnReadyEvent onReadyEvent; ExceptionOrValue& resultRef; @@ -896,8 +1127,9 @@ class EagerPromiseNodeBase: public PromiseNode, protected Event { template class EagerPromiseNode final: public EagerPromiseNodeBase { public: - EagerPromiseNode(Own&& dependency) - : EagerPromiseNodeBase(kj::mv(dependency), result) {} + EagerPromiseNode(OwnPromiseNode&& dependency, SourceLocation location) + : EagerPromiseNodeBase(kj::mv(dependency), result, location) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { output.as() = kj::mv(result); @@ -908,10 +1140,10 @@ class EagerPromiseNode final: public EagerPromiseNodeBase { }; template -Own spark(Own&& node) { +OwnPromiseNode spark(OwnPromiseNode&& node, SourceLocation location) { // Forces evaluation of the given node to begin as soon as possible, even if no one is waiting // on it. - return heap>(kj::mv(node)); + return appendPromise>(kj::mv(node), location); } // ------------------------------------------------------------------- @@ -939,6 +1171,7 @@ class AdapterPromiseNode final: public AdapterPromiseNodeBase, template AdapterPromiseNode(Params&&... params) : adapter(static_cast>&>(*this), kj::fwd(params)...) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { KJ_IREQUIRE(!isWaiting()); @@ -977,8 +1210,8 @@ class FiberBase: public PromiseNode, private Event { // Base class for the outer PromiseNode representing a fiber. public: - explicit FiberBase(size_t stackSize, _::ExceptionOrValue& result); - explicit FiberBase(const FiberPool& pool, _::ExceptionOrValue& result); + explicit FiberBase(size_t stackSize, _::ExceptionOrValue& result, SourceLocation location); + explicit FiberBase(const FiberPool& pool, _::ExceptionOrValue& result, SourceLocation location); ~FiberBase() noexcept(false); void start() { armDepthFirst(); } @@ -991,7 +1224,7 @@ class FiberBase: public PromiseNode, private Event { protected: bool isFinished() { return state == FINISHED; } - void destroy(); + void cancel(); private: enum { WAITING, RUNNING, CANCELED, FINISHED } state; @@ -1009,17 +1242,20 @@ class FiberBase: public PromiseNode, private Event { // Implements Event. Each time the event is fired, switchToFiber() is called. friend class FiberStack; - friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, - WaitScope& waitScope); - friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope); + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); }; template class Fiber final: public FiberBase { public: - explicit Fiber(size_t stackSize, Func&& func): FiberBase(stackSize, result), func(kj::fwd(func)) {} - explicit Fiber(const FiberPool& pool, Func&& func): FiberBase(pool, result), func(kj::fwd(func)) {} - ~Fiber() noexcept(false) { destroy(); } + explicit Fiber(size_t stackSize, Func&& func, SourceLocation location) + : FiberBase(stackSize, result, location), func(kj::fwd(func)) {} + explicit Fiber(const FiberPool& pool, Func&& func, SourceLocation location) + : FiberBase(pool, result, location), func(kj::fwd(func)) {} + ~Fiber() noexcept(false) { cancel(); } + void destroy() override { freePromise(this); } typedef FixVoid()(kj::instance()))> ResultType; @@ -1044,24 +1280,25 @@ class Fiber final: public FiberBase { template Promise::Promise(_::FixVoid value) - : PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid>>(kj::mv(value))) {} + : PromiseBase(_::allocPromise<_::ImmediatePromiseNode<_::FixVoid>>(kj::mv(value))) {} template Promise::Promise(kj::Exception&& exception) - : PromiseBase(heap<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {} + : PromiseBase(_::allocPromise<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {} template template -PromiseForResult Promise::then(Func&& func, ErrorFunc&& errorHandler) { +PromiseForResult Promise::then(Func&& func, ErrorFunc&& errorHandler, + SourceLocation location) { typedef _::FixVoid<_::ReturnType> ResultT; void* continuationTracePtr = _::GetFunctorStartAddress<_::FixVoid&&>::apply(func); - Own<_::PromiseNode> intermediate = - heap<_::TransformPromiseNode, Func, ErrorFunc>>( + _::OwnPromiseNode intermediate = + _::appendPromise<_::TransformPromiseNode, Func, ErrorFunc>>( kj::mv(node), kj::fwd(func), kj::fwd(errorHandler), continuationTracePtr); auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return _::maybeReduce(kj::mv(result), false); } @@ -1095,7 +1332,7 @@ struct IdentityFunc> { template template -Promise Promise::catch_(ErrorFunc&& errorHandler) { +Promise Promise::catch_(ErrorFunc&& errorHandler, SourceLocation location) { // then()'s ErrorFunc can only return a Promise if Func also returns a Promise. In this case, // Func is being filled in automatically. We want to make sure ErrorFunc can return a Promise, // but we don't want the extra overhead of promise chaining if ErrorFunc doesn't actually @@ -1106,29 +1343,30 @@ Promise Promise::catch_(ErrorFunc&& errorHandler) { // The reason catch_() isn't simply implemented in terms of then() is because we want the trace // pointer to be based on ErrorFunc rather than Func. void* continuationTracePtr = _::GetFunctorStartAddress::apply(errorHandler); - Own<_::PromiseNode> intermediate = - heap<_::TransformPromiseNode, Func, ErrorFunc>>( + _::OwnPromiseNode intermediate = + _::appendPromise<_::TransformPromiseNode, Func, ErrorFunc>>( kj::mv(node), Func(), kj::fwd(errorHandler), continuationTracePtr); auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return _::maybeReduce(kj::mv(result), false); } template -T Promise::wait(WaitScope& waitScope) { +T Promise::wait(WaitScope& waitScope, SourceLocation location) { _::ExceptionOr<_::FixVoid> result; - _::waitImpl(kj::mv(node), result, waitScope); + _::waitImpl(kj::mv(node), result, waitScope, location); return convertToReturn(kj::mv(result)); } template -bool Promise::poll(WaitScope& waitScope) { - return _::pollImpl(*node, waitScope); +bool Promise::poll(WaitScope& waitScope, SourceLocation location) { + return _::pollImpl(*node, waitScope, location); } template -ForkedPromise Promise::fork() { - return ForkedPromise(false, refcounted<_::ForkHub<_::FixVoid>>(kj::mv(node))); +ForkedPromise Promise::fork(SourceLocation location) { + return ForkedPromise(false, + _::PromiseDisposer::alloc<_::ForkHub<_::FixVoid>, _::ForkHubBase>(kj::mv(node), location)); } template @@ -1142,34 +1380,36 @@ bool ForkedPromise::hasBranches() { } template -_::SplitTuplePromise Promise::split() { - return refcounted<_::ForkHub<_::FixVoid>>(kj::mv(node))->split(); +_::SplitTuplePromise Promise::split(SourceLocation location) { + return _::PromiseDisposer::alloc<_::ForkHub<_::FixVoid>, _::ForkHubBase>( + kj::mv(node), location)->split(location); } template -Promise Promise::exclusiveJoin(Promise&& other) { - return Promise(false, heap<_::ExclusiveJoinPromiseNode>(kj::mv(node), kj::mv(other.node))); +Promise Promise::exclusiveJoin(Promise&& other, SourceLocation location) { + return Promise(false, _::appendPromise<_::ExclusiveJoinPromiseNode>( + kj::mv(node), kj::mv(other.node), location)); } template template Promise Promise::attach(Attachments&&... attachments) { - return Promise(false, kj::heap<_::AttachmentPromiseNode>>( + return Promise(false, _::appendPromise<_::AttachmentPromiseNode>>( kj::mv(node), kj::tuple(kj::fwd(attachments)...))); } template template -Promise Promise::eagerlyEvaluate(ErrorFunc&& errorHandler) { +Promise Promise::eagerlyEvaluate(ErrorFunc&& errorHandler, SourceLocation location) { // See catch_() for commentary. return Promise(false, _::spark<_::FixVoid>(then( _::IdentityFunc()))>(), - kj::fwd(errorHandler)).node)); + kj::fwd(errorHandler)).node, location)); } template -Promise Promise::eagerlyEvaluate(decltype(nullptr)) { - return Promise(false, _::spark<_::FixVoid>(kj::mv(node))); +Promise Promise::eagerlyEvaluate(decltype(nullptr), SourceLocation location) { + return Promise(false, _::spark<_::FixVoid>(kj::mv(node), location)); } template @@ -1177,6 +1417,12 @@ kj::String Promise::trace() { return PromiseBase::trace(); } +template +inline Promise constPromise() { + static _::ConstPromiseNode NODE; + return _::PromiseNode::to>(_::OwnPromiseNode(&NODE)); +} + template inline PromiseForResult evalLater(Func&& func) { return _::yield().then(kj::fwd(func), _::PropagateException()); @@ -1235,24 +1481,28 @@ inline PromiseForResult retryOnDisconnect(Func&& func) { } template -inline PromiseForResult startFiber(size_t stackSize, Func&& func) { +inline PromiseForResult startFiber( + size_t stackSize, Func&& func, SourceLocation location) { typedef _::FixVoid<_::ReturnType> ResultT; - Own<_::FiberBase> intermediate = kj::heap<_::Fiber>(stackSize, kj::fwd(func)); + auto intermediate = _::allocPromise<_::Fiber>( + stackSize, kj::fwd(func), location); intermediate->start(); auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return _::maybeReduce(kj::mv(result), false); } template -inline PromiseForResult FiberPool::startFiber(Func&& func) const { +inline PromiseForResult FiberPool::startFiber( + Func&& func, SourceLocation location) const { typedef _::FixVoid<_::ReturnType> ResultT; - Own<_::FiberBase> intermediate = kj::heap<_::Fiber>(*this, kj::fwd(func)); + auto intermediate = _::allocPromise<_::Fiber>( + *this, kj::fwd(func), location); intermediate->start(); auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return _::maybeReduce(kj::mv(result), false); } @@ -1269,10 +1519,19 @@ void Promise::detach(ErrorFunc&& errorHandler) { } template -Promise> joinPromises(Array>&& promises) { - return _::PromiseNode::to>>(kj::heap<_::ArrayJoinPromiseNode>( +Promise> joinPromises(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr>(promises.size()), location, + _::ArrayJoinBehavior::LAZY)); +} + +template +Promise> joinPromisesFailFast(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>>(_::allocPromise<_::ArrayJoinPromiseNode>( KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, - heapArray<_::ExceptionOr>(promises.size()))); + heapArray<_::ExceptionOr>(promises.size()), location, + _::ArrayJoinBehavior::EAGER)); } // ======================================================================================= @@ -1315,7 +1574,7 @@ class WeakFulfiller final: public PromiseFulfiller, public WeakFulfillerBase // fulfiller and detach() is called when the promise is destroyed. public: - KJ_DISALLOW_COPY(WeakFulfiller); + KJ_DISALLOW_COPY_AND_MOVE(WeakFulfiller); static kj::Own make() { WeakFulfiller* ptr = new WeakFulfiller; @@ -1399,21 +1658,23 @@ bool PromiseFulfiller::rejectIfThrows(Func&& func) { template _::ReducePromises newAdaptedPromise(Params&&... adapterConstructorParams) { - Own<_::PromiseNode> intermediate( - heap<_::AdapterPromiseNode<_::FixVoid, Adapter>>( + _::OwnPromiseNode intermediate( + _::allocPromise<_::AdapterPromiseNode<_::FixVoid, Adapter>>( kj::fwd(adapterConstructorParams)...)); + // We can't capture SourceLocation in this function's arguments since it is a vararg template. :( return _::PromiseNode::to<_::ReducePromises>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), SourceLocation())); } template -PromiseFulfillerPair newPromiseAndFulfiller() { +PromiseFulfillerPair newPromiseAndFulfiller(SourceLocation location) { auto wrapper = _::WeakFulfiller::make(); - Own<_::PromiseNode> intermediate( - heap<_::AdapterPromiseNode<_::FixVoid, _::PromiseAndFulfillerAdapter>>(*wrapper)); + _::OwnPromiseNode intermediate( + _::allocPromise<_::AdapterPromiseNode< + _::FixVoid, _::PromiseAndFulfillerAdapter>>(*wrapper)); auto promise = _::PromiseNode::to<_::ReducePromises>( - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return PromiseFulfillerPair { kj::mv(promise), kj::mv(wrapper) }; } @@ -1423,10 +1684,11 @@ PromiseFulfillerPair newPromiseAndFulfiller() { namespace _ { // (private) -class XThreadEvent: private Event, // it's an event in the target thread - public PromiseNode { // it's a PromiseNode in the requesting thread +class XThreadEvent: public PromiseNode, // it's a PromiseNode in the requesting thread + private Event { // it's an event in the target thread public: - XThreadEvent(ExceptionOrValue& result, const Executor& targetExecutor, void* funcTracePtr); + XThreadEvent(ExceptionOrValue& result, const Executor& targetExecutor, EventLoop& loop, + void* funcTracePtr, SourceLocation location); void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; @@ -1436,7 +1698,7 @@ class XThreadEvent: private Event, // it's an event in the target thread // still being accessed by the other thread. (This can't be placed in ~XThreadEvent() because // that destructor doesn't run until the subclass has already been destroyed.) - virtual kj::Maybe> execute() = 0; + virtual kj::Maybe execute() = 0; // Run the function. If the function returns a promise, returns the inner PromiseNode, otherwise // returns null. @@ -1450,7 +1712,7 @@ class XThreadEvent: private Event, // it's an event in the target thread kj::Own targetExecutor; Maybe replyExecutor; // If executeAsync() was used. - kj::Maybe> promiseNode; + kj::Maybe promiseNode; // Accessed only in target thread. ListLink targetLink; @@ -1527,14 +1789,15 @@ template >> class XThreadEventImpl final: public XThreadEvent { // Implementation for a function that does not return a Promise. public: - XThreadEventImpl(Func&& func, const Executor& target) - : XThreadEvent(result, target, GetFunctorStartAddress<>::apply(func)), + XThreadEventImpl(Func&& func, const Executor& target, EventLoop& loop, SourceLocation location) + : XThreadEvent(result, target, loop, GetFunctorStartAddress<>::apply(func), location), func(kj::fwd(func)) {} ~XThreadEventImpl() noexcept(false) { ensureDoneOrCanceled(); } + void destroy() override { freePromise(this); } typedef _::FixVoid<_::ReturnType> ResultT; - kj::Maybe> execute() override { + kj::Maybe<_::OwnPromiseNode> execute() override { result.value = MaybeVoidCaller>::apply(func, Void()); return nullptr; } @@ -1554,14 +1817,15 @@ template class XThreadEventImpl> final: public XThreadEvent { // Implementation for a function that DOES return a Promise. public: - XThreadEventImpl(Func&& func, const Executor& target) - : XThreadEvent(result, target, GetFunctorStartAddress<>::apply(func)), + XThreadEventImpl(Func&& func, const Executor& target, EventLoop& loop, SourceLocation location) + : XThreadEvent(result, target, loop, GetFunctorStartAddress<>::apply(func), location), func(kj::fwd(func)) {} ~XThreadEventImpl() noexcept(false) { ensureDoneOrCanceled(); } + void destroy() override { freePromise(this); } typedef _::FixVoid<_::UnwrapPromise>> ResultT; - kj::Maybe> execute() override { + kj::Maybe<_::OwnPromiseNode> execute() override { auto result = _::PromiseNode::from(func()); KJ_IREQUIRE(result.get() != nullptr); return kj::mv(result); @@ -1581,15 +1845,19 @@ class XThreadEventImpl> final: public XThreadEvent { } // namespace _ (private) template -_::UnwrapPromise> Executor::executeSync(Func&& func) const { - _::XThreadEventImpl event(kj::fwd(func), *this); +_::UnwrapPromise> Executor::executeSync( + Func&& func, SourceLocation location) const { + _::XThreadEventImpl event(kj::fwd(func), *this, getLoop(), location); send(event, true); return convertToReturn(kj::mv(event.result)); } template -PromiseForResult Executor::executeAsync(Func&& func) const { - auto event = kj::heap<_::XThreadEventImpl>(kj::fwd(func), *this); +PromiseForResult Executor::executeAsync(Func&& func, SourceLocation location) const { + // HACK: We call getLoop() here, rather than have XThreadEvent's constructor do it, so that if it + // throws we don't crash due to `allocPromise()` being `noexcept`. + auto event = _::allocPromise<_::XThreadEventImpl>( + kj::fwd(func), *this, getLoop(), location); send(*event, false); return _::PromiseNode::to>(kj::mv(event)); } @@ -1605,12 +1873,7 @@ class XThreadPaf: public PromiseNode { public: XThreadPaf(); virtual ~XThreadPaf() noexcept(false); - - class Disposer: public kj::Disposer { - public: - void disposeImpl(void* pointer) const override; - }; - static const Disposer DISPOSER; + void destroy() override; // implements PromiseNode ---------------------------------------------------- void onReady(Event* event) noexcept override; @@ -1697,7 +1960,7 @@ class XThreadPaf::FulfillScope { ~FulfillScope() noexcept(false); - KJ_DISALLOW_COPY(FulfillScope); + KJ_DISALLOW_COPY_AND_MOVE(FulfillScope); bool shouldFulfill() { return obj != nullptr; } @@ -1761,11 +2024,322 @@ class XThreadFulfiller> { template PromiseCrossThreadFulfillerPair newPromiseAndCrossThreadFulfiller() { - kj::Own<_::XThreadPafImpl> node(new _::XThreadPafImpl, _::XThreadPaf::DISPOSER); + kj::Own<_::XThreadPafImpl, _::PromiseDisposer> node(new _::XThreadPafImpl); auto fulfiller = kj::heap<_::XThreadFulfiller>(node); return { _::PromiseNode::to<_::ReducePromises>(kj::mv(node)), kj::mv(fulfiller) }; } } // namespace kj +#if KJ_HAS_COROUTINE + +// ======================================================================================= +// Coroutines TS integration with kj::Promise. +// +// Here's a simple coroutine: +// +// Promise> connectToService(Network& n) { +// auto a = co_await n.parseAddress(IP, PORT); +// auto c = co_await a->connect(); +// co_return kj::mv(c); +// } +// +// The presence of the co_await and co_return keywords tell the compiler it is a coroutine. +// Although it looks similar to a function, it has a couple large differences. First, everything +// that would normally live in the stack frame lives instead in a heap-based coroutine frame. +// Second, the coroutine has the ability to return from its scope without deallocating this frame +// (to suspend, in other words), and the ability to resume from its last suspension point. +// +// In order to know how to suspend, resume, and return from a coroutine, the compiler looks up a +// coroutine implementation type via a traits class parameterized by the coroutine return and +// parameter types. We'll name our coroutine implementation `kj::_::Coroutine`, + +namespace kj::_ { template class Coroutine; } + +// Specializing the appropriate traits class tells the compiler about `kj::_::Coroutine`. + +namespace KJ_COROUTINE_STD_NAMESPACE { + +template +struct coroutine_traits, Args...> { + // `Args...` are the coroutine's parameter types. + + using promise_type = kj::_::Coroutine; + // The Coroutines TS calls this the "promise type". This makes sense when thinking of coroutines + // returning `std::future`, since the coroutine implementation would be a wrapper around + // a `std::promise`. It's extremely confusing from a KJ perspective, however, so I call it + // the "coroutine implementation type" instead. +}; + +} // namespace KJ_COROUTINE_STD_NAMESPACE + +// Now when the compiler sees our `connectToService()` coroutine above, it default-constructs a +// `coroutine_traits>, Network&>::promise_type`, or +// `kj::_::Coroutine>`. +// +// The implementation object lives in the heap-allocated coroutine frame. It gets destroyed and +// deallocated when the frame does. + +namespace kj::_ { + +namespace stdcoro = KJ_COROUTINE_STD_NAMESPACE; + +class CoroutineBase: public PromiseNode, + public Event { +public: + CoroutineBase(stdcoro::coroutine_handle<> coroutine, ExceptionOrValue& resultRef, + SourceLocation location); + ~CoroutineBase() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(CoroutineBase); + void destroy() override; + + auto initial_suspend() { return stdcoro::suspend_never(); } + auto final_suspend() noexcept { +#if _MSC_VER && !defined(__clang__) + // See comment at `finalSuspendCalled`'s definition. + finalSuspendCalled = true; +#endif + return stdcoro::suspend_always(); + } + // These adjust the suspension behavior of coroutines immediately upon initiation, and immediately + // after completion. + // + // The initial suspension point could allow us to defer the initial synchronous execution of a + // coroutine -- everything before its first co_await, that is. + // + // The final suspension point is useful to delay deallocation of the coroutine frame to match the + // lifetime of the enclosing promise. + + void unhandled_exception(); + +protected: + class AwaiterBase; + + bool isWaiting() { return waiting; } + void scheduleResumption() { + onReadyEvent.arm(); + waiting = false; + } + +private: + // ------------------------------------------------------- + // PromiseNode implementation + + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; + + // ------------------------------------------------------- + // Event implementation + + Maybe> fire() override; + void traceEvent(TraceBuilder& builder) override; + + stdcoro::coroutine_handle<> coroutine; + ExceptionOrValue& resultRef; + + OnReadyEvent onReadyEvent; + bool waiting = true; + + bool hasSuspendedAtLeastOnce = false; + +#if _MSC_VER && !defined(__clang__) + bool finalSuspendCalled = false; + // MSVC erroneously reports the coroutine as done (that is, `coroutine.done()` returns true) + // seemingly as soon as `return_value()`/`return_void()` are called. This matters in our + // implementation of `unhandled_exception()`, which must arrange to propagate exceptions during + // coroutine frame unwind via the returned promise, even if `return_value()`/`return_void()` have + // already been called. To prove that our assumptions are correct in that function, we want to be + // able to assert that `final_suspend()` has not yet been called. This boolean hack allows us to + // preserve that assertion. +#endif + + Maybe promiseNodeForTrace; + // Whenever this coroutine is suspended waiting on another promise, we keep a reference to that + // promise so tracePromise()/traceEvent() can trace into it. + + UnwindDetector unwindDetector; + + struct DisposalResults { + bool destructorRan = false; + Maybe exception; + }; + Maybe maybeDisposalResults; + // Only non-null during destruction. Before calling coroutine.destroy(), our disposer sets this + // to point to a DisposalResults on the stack so unhandled_exception() will have some place to + // store unwind exceptions. We can't store them in this Coroutine, because we'll be destroyed once + // coroutine.destroy() has returned. Our disposer then rethrows as needed. +}; + +template +class CoroutineMixin; +// CRTP mixin, covered later. + +template +class Coroutine final: public CoroutineBase, + public CoroutineMixin, T> { + // The standard calls this the `promise_type` object. We can call this the "coroutine + // implementation object" since the word promise means different things in KJ and std styles. This + // is where we implement how a `kj::Promise` is returned from a coroutine, and how that promise + // is later fulfilled. We also fill in a few lifetime-related details. + // + // The implementation object is also where we can customize memory allocation of coroutine frames, + // by implementing a member `operator new(size_t, Args...)` (same `Args...` as in + // coroutine_traits). + // + // We can also customize how await-expressions are transformed within `kj::Promise`-based + // coroutines by implementing an `await_transform(P)` member function, where `P` is some type for + // which we want to implement co_await support, e.g. `kj::Promise`. This feature allows us to + // provide an optimized `kj::EventLoop` integration when the coroutine's return type and the + // await-expression's type are both `kj::Promise` instantiations -- see further comments under + // `await_transform()`. + +public: + using Handle = stdcoro::coroutine_handle>; + + Coroutine(SourceLocation location = {}) + : CoroutineBase(Handle::from_promise(*this), result, location) {} + + Promise get_return_object() { + // Called after coroutine frame construction and before initial_suspend() to create the + // coroutine's return object. `this` itself lives inside the coroutine frame, and we arrange for + // the returned Promise to own `this` via a custom Disposer and by always leaving the + // coroutine in a suspended state. + return PromiseNode::to>(OwnPromiseNode(this)); + } + +public: + template + class Awaiter; + + template + Awaiter await_transform(kj::Promise& promise) { return Awaiter(kj::mv(promise)); } + template + Awaiter await_transform(kj::Promise&& promise) { return Awaiter(kj::mv(promise)); } + // Called when someone writes `co_await promise`, where `promise` is a kj::Promise. We return + // an Awaiter, which implements coroutine suspension and resumption in terms of the KJ async + // event system. + // + // There is another hook we could implement: an `operator co_await()` free function. However, a + // free function would be unaware of the type of the enclosing coroutine. Since Awaiter is a + // member class template of Coroutine, it is able to implement an + // `await_suspend(Coroutine::Handle)` override, providing it type-safe access to our enclosing + // coroutine's PromiseNode. An `operator co_await()` free function would have to implement + // a type-erased `await_suspend(stdcoro::coroutine_handle)` override, and implement + // suspension and resumption in terms of .then(). Yuck! + +private: + // ------------------------------------------------------- + // PromiseNode implementation + + void get(ExceptionOrValue& output) noexcept override { + output.as>() = kj::mv(result); + } + + void fulfill(FixVoid&& value) { + // Called by the return_value()/return_void() functions in our mixin class. + + if (isWaiting()) { + result = kj::mv(value); + scheduleResumption(); + } + } + + ExceptionOr> result; + + friend class CoroutineMixin, T>; +}; + +template +class CoroutineMixin { +public: + void return_value(T value) { + static_cast(this)->fulfill(kj::mv(value)); + } +}; +template +class CoroutineMixin { +public: + void return_void() { + static_cast(this)->fulfill(_::Void()); + } +}; +// The Coroutines spec has no `_::FixVoid` equivalent to unify valueful and valueless co_return +// statements, and programs are ill-formed if the coroutine implementation object (Coroutine) has +// both a `return_value()` and `return_void()`. No amount of EnableIffery can get around it, so +// these return_* functions live in a CRTP mixin. + +class CoroutineBase::AwaiterBase { +public: + explicit AwaiterBase(OwnPromiseNode node); + AwaiterBase(AwaiterBase&&); + ~AwaiterBase() noexcept(false); + KJ_DISALLOW_COPY(AwaiterBase); + + bool await_ready() const { return false; } + // This could return "`node->get()` is safe to call" instead, which would make suspension-less + // co_awaits possible for immediately-fulfilled promises. However, we need an Event to figure that + // out, and we won't have access to the Coroutine Event until await_suspend() is called. So, we + // must return false here. Fortunately, await_suspend() has a trick up its sleeve to enable + // suspension-less co_awaits. + +protected: + void getImpl(ExceptionOrValue& result, void* awaitedAt); + bool awaitSuspendImpl(CoroutineBase& coroutineEvent); + +private: + UnwindDetector unwindDetector; + OwnPromiseNode node; + + Maybe maybeCoroutineEvent; + // If we do suspend waiting for our wrapped promise, we store a reference to `node` in our + // enclosing Coroutine for tracing purposes. To guard against any edge cases where an async stack + // trace is generated when an Awaiter was destroyed without Coroutine::fire() having been called, + // we need our own reference to the enclosing Coroutine. (I struggle to think up any such + // scenarios, but perhaps they could occur when destroying a suspended coroutine.) +}; + +template +template +class Coroutine::Awaiter: public AwaiterBase { + // Wrapper around a co_await'ed promise and some storage space for the result of that promise. + // The compiler arranges to call our await_suspend() to suspend, which arranges to be woken up + // when the awaited promise is settled. Once that happens, the enclosing coroutine's Event + // implementation resumes the coroutine, which transitively calls await_resume() to unwrap the + // awaited promise result. + +public: + explicit Awaiter(Promise promise): AwaiterBase(PromiseNode::from(kj::mv(promise))) {} + + KJ_NOINLINE U await_resume() { + // This is marked noinline in order to ensure __builtin_return_address() is accurate for stack + // trace purposes. In my experimentation, this method was not inlined anyway even in opt + // builds, but I want to make sure it doesn't suddenly start being inlined later causing stack + // traces to break. (I also tried always-inline, but this did not appear to cause the compiler + // to inline the method -- perhaps a limitation of coroutines?) +#if __GNUC__ + getImpl(result, __builtin_return_address(0)); +#elif _MSC_VER + getImpl(result, _ReturnAddress()); +#else + #error "please implement for your compiler" +#endif + auto value = kj::_::readMaybe(result.value); + KJ_IASSERT(value != nullptr, "Neither exception nor value present."); + return U(kj::mv(*value)); + } + + bool await_suspend(Coroutine::Handle coroutine) { + return awaitSuspendImpl(coroutine.promise()); + } + +private: + ExceptionOr> result; +}; + +#undef KJ_COROUTINE_STD_NAMESPACE + +} // namespace kj::_ (private) + +#endif // KJ_HAS_COROUTINE + KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-internal.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-internal.h index 9f4dd01331e..d030ad9577f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-internal.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-internal.h @@ -25,6 +25,10 @@ #include "vector.h" #include "async-io.h" #include +#include "one-of.h" +#include "cidr.h" + +KJ_BEGIN_HEADER struct sockaddr; struct sockaddr_un; @@ -40,32 +44,6 @@ kj::ArrayPtr safeUnixPath(const struct sockaddr_un* addr, uint addrl // paths MUST be read using this function. #endif -class CidrRange { -public: - CidrRange(StringPtr pattern); - - static CidrRange inet4(ArrayPtr bits, uint bitCount); - static CidrRange inet6(ArrayPtr prefix, ArrayPtr suffix, - uint bitCount); - // Zeros are inserted between `prefix` and `suffix` to extend the address to 128 bits. - - uint getSpecificity() const { return bitCount; } - - bool matches(const struct sockaddr* addr) const; - bool matchesFamily(int family) const; - - String toString() const; - -private: - int family; - byte bits[16]; - uint bitCount; // how many bits in `bits` need to match - - CidrRange(int family, ArrayPtr bits, uint bitCount); - - void zeroIrrelevantBits(); -}; - class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter { public: NetworkFilter(); @@ -80,9 +58,13 @@ class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter { Vector denyCidrs; bool allowUnix; bool allowAbstractUnix; + bool allowPublic = false; + bool allowNetwork = false; kj::Maybe next; }; } // namespace _ (private) } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-test.c++ index dc454f4107a..e8892b79e67 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-test.c++ @@ -30,10 +30,12 @@ #include "async-io-internal.h" #include "debug.h" #include "io.h" +#include "cidr.h" #include "miniposix.h" #include #include #include +#include #if _WIN32 #include #include "windows-sanity.h" @@ -89,7 +91,7 @@ TEST(AsyncIo, SimpleNetwork) { EXPECT_EQ("foo", result); } -#if !_WIN32 // TODO(0.10): Implement NetworkPeerIdentity for Win32. +#if !_WIN32 // TODO(someday): Implement NetworkPeerIdentity for Win32. TEST(AsyncIo, SimpleNetworkAuthentication) { auto ioContext = setupAsyncIo(); auto& network = ioContext.provider->getNetwork(); @@ -373,9 +375,17 @@ bool systemSupportsAddress(StringPtr addr, StringPtr service = nullptr) { // Can getaddrinfo() parse this addresses? This is only true if the address family (e.g., ipv6) // is configured on at least one interface. (The loopback interface usually has both ipv4 and // ipv6 configured, but not always.) + struct addrinfo hints; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = 0; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; + hints.ai_protocol = 0; + hints.ai_canonname = nullptr; + hints.ai_addr = nullptr; + hints.ai_next = nullptr; struct addrinfo* list; int status = getaddrinfo( - addr.cStr(), service == nullptr ? nullptr : service.cStr(), nullptr, &list); + addr.cStr(), service == nullptr ? nullptr : service.cStr(), &hints, &list); if (status == 0) { freeaddrinfo(list); return true; @@ -1072,12 +1082,12 @@ TEST(AsyncIo, AbstractUnixSocket) { #endif // __linux__ KJ_TEST("CIDR parsing") { - KJ_EXPECT(_::CidrRange("1.2.3.4/16").toString() == "1.2.0.0/16"); - KJ_EXPECT(_::CidrRange("1.2.255.4/18").toString() == "1.2.192.0/18"); - KJ_EXPECT(_::CidrRange("1234::abcd:ffff:ffff/98").toString() == "1234::abcd:c000:0/98"); + KJ_EXPECT(CidrRange("1.2.3.4/16").toString() == "1.2.0.0/16"); + KJ_EXPECT(CidrRange("1.2.255.4/18").toString() == "1.2.192.0/18"); + KJ_EXPECT(CidrRange("1234::abcd:ffff:ffff/98").toString() == "1234::abcd:c000:0/98"); - KJ_EXPECT(_::CidrRange::inet4({1,2,255,4}, 18).toString() == "1.2.192.0/18"); - KJ_EXPECT(_::CidrRange::inet6({0x1234, 0x5678}, {0xabcd, 0xffff, 0xffff}, 98).toString() == + KJ_EXPECT(CidrRange::inet4({1,2,255,4}, 18).toString() == "1.2.192.0/18"); + KJ_EXPECT(CidrRange::inet6({0x1234, 0x5678}, {0xabcd, 0xffff, 0xffff}, 98).toString() == "1234:5678::abcd:c000:0/98"); union { @@ -1090,37 +1100,37 @@ KJ_TEST("CIDR parsing") { { addr4.sin_family = AF_INET; addr4.sin_addr.s_addr = htonl(0x0102dfff); - KJ_EXPECT(_::CidrRange("1.2.255.255/18").matches(&addr)); - KJ_EXPECT(!_::CidrRange("1.2.255.255/19").matches(&addr)); - KJ_EXPECT(_::CidrRange("1.2.0.0/16").matches(&addr)); - KJ_EXPECT(!_::CidrRange("1.3.0.0/16").matches(&addr)); - KJ_EXPECT(_::CidrRange("1.2.223.255/32").matches(&addr)); - KJ_EXPECT(_::CidrRange("0.0.0.0/0").matches(&addr)); - KJ_EXPECT(!_::CidrRange("::/0").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.255.255/18").matches(&addr)); + KJ_EXPECT(!CidrRange("1.2.255.255/19").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.0.0/16").matches(&addr)); + KJ_EXPECT(!CidrRange("1.3.0.0/16").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.223.255/32").matches(&addr)); + KJ_EXPECT(CidrRange("0.0.0.0/0").matches(&addr)); + KJ_EXPECT(!CidrRange("::/0").matches(&addr)); } { addr4.sin_family = AF_INET6; byte bytes[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; memcpy(addr6.sin6_addr.s6_addr, bytes, 16); - KJ_EXPECT(_::CidrRange("0102:03ff::/24").matches(&addr)); - KJ_EXPECT(!_::CidrRange("0102:02ff::/24").matches(&addr)); - KJ_EXPECT(_::CidrRange("0102:02ff::/23").matches(&addr)); - KJ_EXPECT(_::CidrRange("0102:0304:0506:0708:090a:0b0c:0d0e:0f10/128").matches(&addr)); - KJ_EXPECT(_::CidrRange("::/0").matches(&addr)); - KJ_EXPECT(!_::CidrRange("0.0.0.0/0").matches(&addr)); + KJ_EXPECT(CidrRange("0102:03ff::/24").matches(&addr)); + KJ_EXPECT(!CidrRange("0102:02ff::/24").matches(&addr)); + KJ_EXPECT(CidrRange("0102:02ff::/23").matches(&addr)); + KJ_EXPECT(CidrRange("0102:0304:0506:0708:090a:0b0c:0d0e:0f10/128").matches(&addr)); + KJ_EXPECT(CidrRange("::/0").matches(&addr)); + KJ_EXPECT(!CidrRange("0.0.0.0/0").matches(&addr)); } { addr4.sin_family = AF_INET6; inet_pton(AF_INET6, "::ffff:1.2.223.255", &addr6.sin6_addr); - KJ_EXPECT(_::CidrRange("1.2.255.255/18").matches(&addr)); - KJ_EXPECT(!_::CidrRange("1.2.255.255/19").matches(&addr)); - KJ_EXPECT(_::CidrRange("1.2.0.0/16").matches(&addr)); - KJ_EXPECT(!_::CidrRange("1.3.0.0/16").matches(&addr)); - KJ_EXPECT(_::CidrRange("1.2.223.255/32").matches(&addr)); - KJ_EXPECT(_::CidrRange("0.0.0.0/0").matches(&addr)); - KJ_EXPECT(_::CidrRange("::/0").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.255.255/18").matches(&addr)); + KJ_EXPECT(!CidrRange("1.2.255.255/19").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.0.0/16").matches(&addr)); + KJ_EXPECT(!CidrRange("1.3.0.0/16").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.223.255/32").matches(&addr)); + KJ_EXPECT(CidrRange("0.0.0.0/0").matches(&addr)); + KJ_EXPECT(CidrRange("::/0").matches(&addr)); } } @@ -1191,6 +1201,58 @@ KJ_TEST("NetworkFilter") { KJ_EXPECT(allowed4(filter, "1.2.3.1")); KJ_EXPECT(!allowed4(filter, "1.2.3.4")); } + + // Test combinations of public/private/network/local. At one point these were buggy. + { + _::NetworkFilter filter({"public", "private"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "192.168.0.1")); + KJ_EXPECT(allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"network", "local"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "192.168.0.1")); + KJ_EXPECT(allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"public", "local"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(!allowed4(filter, "192.168.0.1")); + KJ_EXPECT(!allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(!allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } } KJ_TEST("Network::restrictPeers()") { @@ -1226,7 +1288,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -1237,7 +1299,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { } return expectRead(in, expected.slice(amount)); - })); + }); } class MockAsyncInputStream final: public AsyncInputStream { @@ -2075,6 +2137,39 @@ KJ_TEST("Userland tee") { expectRead(*right, "foobar").wait(ws); } +KJ_TEST("Userland nested tee") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto tee2 = newTee(kj::mv(right)); + auto rightLeft = kj::mv(tee2.branches[0]); + auto rightRight = kj::mv(tee2.branches[1]); + + auto writePromise = pipe.out->write("foobar", 6); + + expectRead(*left, "foobar").wait(ws); + writePromise.wait(ws); + expectRead(*rightLeft, "foobar").wait(ws); + expectRead(*rightRight, "foo").wait(ws); + + auto tee3 = newTee(kj::mv(rightRight)); + auto rightRightLeft = kj::mv(tee3.branches[0]); + auto rightRightRight = kj::mv(tee3.branches[1]); + expectRead(*rightRightLeft, "bar").wait(ws); + expectRead(*rightRightRight, "b").wait(ws); + + auto tee4 = newTee(kj::mv(rightRightRight)); + auto rightRightRightLeft = kj::mv(tee4.branches[0]); + auto rightRightRightRight = kj::mv(tee4.branches[1]); + expectRead(*rightRightRightLeft, "ar").wait(ws); + expectRead(*rightRightRightRight, "ar").wait(ws); +} + KJ_TEST("Userland tee concurrent read") { kj::EventLoop loop; WaitScope ws(loop); @@ -2690,7 +2785,7 @@ KJ_TEST("Userland tee pump cancellation implies write cancellation") { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { leftPipe.out = nullptr; })) { - KJ_FAIL_EXPECT("write promises were not canceled", exception); + KJ_FAIL_EXPECT("write promises were not canceled", *exception); } } @@ -2885,5 +2980,447 @@ KJ_TEST("import socket FD that's already broken") { #endif // !__CYGWIN__ #endif // !_WIN32 +KJ_TEST("AggregateConnectionReceiver") { + EventLoop loop; + WaitScope ws(loop); + + auto pipe1 = newCapabilityPipe(); + auto pipe2 = newCapabilityPipe(); + + auto receiversBuilder = kj::heapArrayBuilder>(2); + receiversBuilder.add(kj::heap(*pipe1.ends[0])); + receiversBuilder.add(kj::heap(*pipe2.ends[0])); + + auto aggregate = newAggregateConnectionReceiver(receiversBuilder.finish()); + + CapabilityStreamNetworkAddress connector1(nullptr, *pipe1.ends[1]); + CapabilityStreamNetworkAddress connector2(nullptr, *pipe2.ends[1]); + + auto connectAndWrite = [&](NetworkAddress& addr, kj::StringPtr text) { + return addr.connect() + .then([text](Own stream) { + auto promise = stream->write(text.begin(), text.size()); + return promise.attach(kj::mv(stream)); + }).eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + }; + + auto acceptAndRead = [&](ConnectionReceiver& socket, kj::StringPtr expected) { + return socket + .accept().then([](Own stream) { + auto promise = stream->readAllText(); + return promise.attach(kj::mv(stream)); + }).then([expected](kj::String actual) { + KJ_EXPECT(actual == expected); + }).eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + }; + + auto connectPromise1 = connectAndWrite(connector1, "foo"); + KJ_EXPECT(!connectPromise1.poll(ws)); + auto connectPromise2 = connectAndWrite(connector2, "bar"); + KJ_EXPECT(!connectPromise2.poll(ws)); + + acceptAndRead(*aggregate, "foo").wait(ws); + + auto connectPromise3 = connectAndWrite(connector1, "baz"); + KJ_EXPECT(!connectPromise3.poll(ws)); + + acceptAndRead(*aggregate, "bar").wait(ws); + acceptAndRead(*aggregate, "baz").wait(ws); + + connectPromise1.wait(ws); + connectPromise2.wait(ws); + connectPromise3.wait(ws); + + auto acceptPromise1 = acceptAndRead(*aggregate, "qux"); + auto acceptPromise2 = acceptAndRead(*aggregate, "corge"); + auto acceptPromise3 = acceptAndRead(*aggregate, "grault"); + + KJ_EXPECT(!acceptPromise1.poll(ws)); + KJ_EXPECT(!acceptPromise2.poll(ws)); + KJ_EXPECT(!acceptPromise3.poll(ws)); + + // Cancel one of the acceptors... + { auto drop = kj::mv(acceptPromise2); } + + connectAndWrite(connector2, "qux").wait(ws); + connectAndWrite(connector1, "grault").wait(ws); + + acceptPromise1.wait(ws); + acceptPromise3.wait(ws); +} + +// ======================================================================================= +// Tests for optimized pumpTo() between OS handles. Note that this is only even optimized on +// some OSes (only Linux as of this writing), but the behavior should still be the same on all +// OSes, so we run the tests regardless. + +kj::String bigString(size_t size) { + auto result = kj::heapString(size); + for (auto i: kj::zeroTo(size)) { + result[i] = 'a' + i % 26; + } + return result; +} + +KJ_TEST("OS handle pumpTo") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + + { + auto readPromise = expectRead(*pipe2.ends[1], "foo"); + pipe1.ends[0]->write("foo", 3).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], "bar"); + pipe1.ends[0]->write("bar", 3).wait(ws); + readPromise.wait(ws); + } + + auto two = bigString(2000); + auto four = bigString(4000); + auto eight = bigString(8000); + auto fiveHundred = bigString(500'000); + + { + auto readPromise = expectRead(*pipe2.ends[1], two); + pipe1.ends[0]->write(two.begin(), two.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], four); + pipe1.ends[0]->write(four.begin(), four.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], eight); + pipe1.ends[0]->write(eight.begin(), eight.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], fiveHundred); + pipe1.ends[0]->write(fiveHundred.begin(), fiveHundred.size()).wait(ws); + readPromise.wait(ws); + } + + KJ_EXPECT(!pump.poll(ws)) + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == 6 + two.size() + four.size() + eight.size() + fiveHundred.size()); +} + +KJ_TEST("OS handle pumpTo small limit") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 500); + + auto text = bigString(1000); + + auto expected = kj::str(text.slice(0, 500)); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + pipe1.ends[0]->write(text.begin(), text.size()).wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 500); + + expectRead(*pipe1.ends[1], text.slice(500)).wait(ws); +} + +KJ_TEST("OS handle pumpTo small limit -- write first then read") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto text = bigString(1000); + + auto expected = kj::str(text.slice(0, 500)); + + // Initiate the write first and let it put as much in the buffer as possible. + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Now start the pump. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 500); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + writePromise.wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 500); + + expectRead(*pipe1.ends[1], text.slice(500)).wait(ws); +} + +KJ_TEST("OS handle pumpTo large limit") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 750'000); + + auto text = bigString(500'000); + + auto expected = kj::str(text, text.slice(0, 250'000)); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + pipe1.ends[0]->write(text.begin(), text.size()).wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 750'000); + + expectRead(*pipe1.ends[1], text.slice(250'000)).wait(ws); +} + +KJ_TEST("OS handle pumpTo large limit -- write first then read") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto text = bigString(500'000); + + auto expected = kj::str(text, text.slice(0, 250'000)); + + // Initiate the write first and let it put as much in the buffer as possible. + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Now start the pump. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 750'000); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + writePromise.wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 750'000); + + expectRead(*pipe1.ends[1], text.slice(250'000)).wait(ws); +} + +#if !_WIN32 +kj::String fillWriteBuffer(int fd) { + // Fill up the write buffer of the given FD and return the contents written. We need to use the + // raw syscalls to do this because KJ doesn't have a way to know how many bytes made it into the + // socket buffer. + auto huge = bigString(2'000'000); + + size_t pos = 0; + for (;;) { + KJ_ASSERT(pos < huge.size(), "whoa, big buffer"); + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::write(fd, huge.begin() + pos, huge.size() - pos)); + if (n < 0) break; + pos += n; + } + + return kj::str(huge.slice(0, pos)); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes. + auto writePromise = pipe1.ends[0]->write("foo", 3); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // Queue another write, even. + writePromise = writePromise + .then([&]() { return pipe1.ends[0]->write("bar", 3); }); + writePromise.poll(ws); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foobar").wait(ws); + + writePromise.wait(ws); + + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == 6); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and pump ends early") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto writePromise = pipe1.ends[0]->write("foo", 3) + .then([&]() { pipe1.ends[0]->shutdownWrite(); }); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foo").wait(ws); + + writePromise.wait(ws); + + KJ_EXPECT(pump.wait(ws) == 3); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and pump hits limit early") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto writePromise = pipe1.ends[0]->write("foo", 3); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 3); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foo").wait(ws); + + writePromise.wait(ws); + + KJ_EXPECT(pump.wait(ws) == 3); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and a lot of data is pumped") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto text = bigString(500'000); + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], text).wait(ws); + + writePromise.wait(ws); + + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == text.size()); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} +#endif + +KJ_TEST("pump file to socket") { + // Tests sendfile() optimization + + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto doTest = [&](kj::Own file) { + file->writeAll("foobar"_kj.asBytes()); + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "foobar"); + KJ_EXPECT(input.getOffset() == 6); + } + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0], 3).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "foo"); + KJ_EXPECT(input.getOffset() == 3); + } + + { + FileInputStream input(*file, 3); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "bar"); + KJ_EXPECT(input.getOffset() == 6); + } + + auto big = bigString(500'000); + file->writeAll(big); + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + // Extra parens here so that we don't write the big string to the console on failure... + KJ_EXPECT((readPromise.wait(ws) == big)); + KJ_EXPECT(input.getOffset() == big.size()); + } + }; + + // Try with an in-memory file. No optimization is possible. + doTest(kj::newInMemoryFile(kj::nullClock())); + + // Try with a disk file. Should use sendfile(). + auto fs = kj::newDiskFilesystem(); + doTest(fs->getCurrent().createTemporary()); +} + } // namespace } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-unix.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-unix.c++ index 551f4c44449..62ce21323e8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-unix.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-unix.c++ @@ -26,6 +26,12 @@ #define _GNU_SOURCE #endif +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t for sendfile(). (The code will still work if we get 32-bit off_t as long +// as actual files are under 4GB.) +#endif + #include "async-io.h" #include "async-io-internal.h" #include "async-unix.h" @@ -50,13 +56,18 @@ #include #include #include +#include + +#if __linux__ +#include +#endif #if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED) #include #endif -#if !defined(SOL_LOCAL) && (__FreeBSD__ || __DragonflyBSD__) -// On DragonFly or FreeBSD < 12.2 you're supposed to use 0 for SOL_LOCAL. +#if !defined(SOL_LOCAL) && (__FreeBSD__ || __DragonflyBSD__ || __APPLE__) +// On DragonFly, FreeBSD < 12.2 and older Darwin you're supposed to use 0 for SOL_LOCAL. #define SOL_LOCAL 0 #endif @@ -137,10 +148,10 @@ private: class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream { public: - AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags) + AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags, uint observerFlags) : OwnedFileDescriptor(fd, flags), eventPort(eventPort), - observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {} + observer(eventPort, fd, observerFlags) {} virtual ~AsyncStreamFd() noexcept(false) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { @@ -163,7 +174,8 @@ public: (ReadResult result) mutable { for (auto i: kj::zeroTo(result.capCount)) { streamBuffer[i] = kj::heap(eventPort, fdBuffer[i].release(), - LowLevelAsyncIoProvider::TAKE_OWNERSHIP | LowLevelAsyncIoProvider::ALREADY_CLOEXEC); + LowLevelAsyncIoProvider::TAKE_OWNERSHIP | LowLevelAsyncIoProvider::ALREADY_CLOEXEC, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); } return result; }); @@ -228,6 +240,240 @@ public: return promise.attach(kj::mv(fds), kj::mv(streams)); } + Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount = kj::maxValue) override { +#if __linux__ && !__ANDROID__ + KJ_IF_MAYBE(sock, kj::dynamicDowncastIfAvailable(input)) { + return pumpFromOther(*sock, amount); + } +#endif + +#if __linux__ + KJ_IF_MAYBE(file, kj::dynamicDowncastIfAvailable(input)) { + KJ_IF_MAYBE(fd, file->getUnderlyingFile().getFd()) { + return pumpFromFile(*file, *fd, amount, 0); + } + } +#endif + + return nullptr; + } + +#if __linux__ + // TODO(someday): Support sendfile on other OS's... unfortunately, it works differently on + // different systems. + +private: + Promise pumpFromFile(FileInputStream& input, int fileFd, + uint64_t amount, uint64_t soFar) { + while (soFar < amount) { + off_t offset = input.getOffset(); + ssize_t n; + + // Although sendfile()'s last argument has type size_t, on Linux it seems to cause EINVAL + // if we pass an amount that is greater than UINT32_MAX, so make sure to clamp to that. In + // practice, of course, we'll be limited to the socket buffer size. + size_t requested = kj::min(amount - soFar, (uint32_t)kj::maxValue); + + KJ_SYSCALL_HANDLE_ERRORS(n = sendfile(fd, fileFd, &offset, requested)) { + case EINVAL: + case ENOSYS: + // Fall back to regular pump + return unoptimizedPumpTo(input, *this, amount, soFar); + + case EAGAIN: + return observer.whenBecomesWritable() + .then([this, &input, fileFd, amount, soFar]() { + return pumpFromFile(input, fileFd, amount, soFar); + }); + + default: + KJ_FAIL_SYSCALL("sendfile", error); + } + + if (n == 0) break; + + input.seek(offset); // NOTE: sendfile() updated `offset` in-place. + soFar += n; + } + + return soFar; + } + +public: +#endif // __linux__ + +#if __linux__ && !__ANDROID__ +// Linux's splice() syscall lets us optimize pumping of bytes between file descriptors. +// +// TODO(someday): splice()-based pumping hangs in unit tests on Android for some reason. We should +// figure out why, but for now I'm just disabling it... + +private: + Maybe> pumpFromOther(AsyncStreamFd& input, uint64_t amount) { + // The input is another AsyncStreamFd, so perhaps we can do an optimized pump with splice(). + + // Before we resort to a bunch of syscalls, let's try to see if the pump is small and able to + // be fully satisfied immediately. This optimizes for the case of small streams, e.g. a short + // HTTP body. + + byte buffer[4096]; + size_t pos = 0; + size_t initialAmount = kj::min(sizeof(buffer), amount); + + bool eof = false; + + // Read into the buffer until it's full or there are no bytes available. Note that we'd expect + // one call to read() will pull as much data out of the socket as possible (up to our buffer + // size), so you might think the loop is unnecessary. The reason we want to do a second read(), + // though, is to find out if we're at EOF or merely waiting for more data. In the EOF case, + // we can end the pump early without splicing. + while (pos < initialAmount) { + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::read(input.fd, buffer + pos, initialAmount - pos)); + if (n <= 0) { + eof = n == 0; + break; + } + pos += n; + } + + // Write the bytes that we just read back out to the output. + { + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::write(fd, buffer, pos)); + if (n < 0) n = 0; // treat EAGAIN as "zero bytes written" + if (size_t(n) < pos) { + // Oh crap, the output buffer is full. This should be rare. But, now we're going to have + // to copy the remaining bytes into the heap to do an async write. + auto leftover = kj::heapArray(buffer + n, pos - n); + auto promise = write(leftover.begin(), leftover.size()); + promise = promise.attach(kj::mv(leftover)); + if (eof || pos == amount) { + return promise.then([pos]() -> uint64_t { return pos; }); + } else { + return promise.then([&input, this, pos, amount]() { + return splicePumpFrom(input, pos, amount); + }); + } + } + } + + if (eof || pos == amount) { + // We finished the pump in one go, so don't splice. + return Promise(uint64_t(pos)); + } else { + // Use splice for the rest of the pump. + return splicePumpFrom(input, pos, amount); + } + } + + static constexpr size_t MAX_SPLICE_LEN = 1 << 20; + // Maximum value we'll pass for the `len` argument of `splice()`. Linux does not like it when we + // use `kj::maxValue` here so we clamp it. Note that the actual value of this constant is + // irrelevanta as long as it is more than the pipe buffer size (typically 64k) and less than + // whatever value makes Linux unhappy. All actual operations will be clamped to the buffer size. + // (And if the buffer size is for some reason larger than this, that's OK too, we just won't + // end up using the whole buffer.) + + Promise splicePumpFrom(AsyncStreamFd& input, uint64_t readSoFar, uint64_t limit) { + // splice() requires that either its input or its output is a pipe. But chances are neither + // `input.fd` nor `this->fd` is a pipe -- in most use cases they are sockets. In order to take + // advantage of splice(), then, we need to allocate a pipe to act as the middleman, so we can + // splice() from the input to the pipe, and then from the pipe to the output. + // + // You might wonder why this pipe middleman is required. Why can't splice() go directly from + // a socket to a socket? Linus Torvalds attempts to explain here: + // https://yarchive.net/comp/linux/splice.html + // + // The short version is that the pipe itself is equivalent to an in-memory buffer. In a naive + // pump implementation, we allocate a buffer, read() into it and write() out. With splice(), + // we allocate a kernelspace buffer by allocating a pipe, then we splice() into the pipe and + // splice() back out. + + // Linux normally allocates pipe buffers of 64k (16 pages of 4k each). However, when + // /proc/sys/fs/pipe-user-pages-soft is hit, then Linux will start allocating 4k (1 page) + // buffers instead, and will give an error if we try to increase it. + // + // The soft limit defaults to 16384 pages, which we'd hit after 1024 pipes -- totally possible + // in a big server. 64k is a nice buffer size, but even 4k is better than not using splice, so + // we'll live with whatever buffer size the kernel gives us. + // + // There is a second, "hard" limit, /proc/sys/fs/pipe-user-pages-hard, at which point Linux + // will start refusing to allocate pipes at all. In this case we fall back to an unoptimized + // pump. However, this limit defaults to unlimited, so this won't ever happen unless someone + // has manually changed the limit. That's probably dangerous since if the app allocates pipes + // anywhere else in its codebase, it probably doesn't have any fallbacks in those places, so + // things will break anyway... to avoid that we'd need to self-regulate the number of pipes + // we allocate here to avoid coming close to the hard limit, but that's a lot of effort so I'm + // not going to bother! + + int pipeFds[2]; + KJ_SYSCALL_HANDLE_ERRORS(pipe2(pipeFds, O_NONBLOCK | O_CLOEXEC)) { + case ENFILE: + // Probably hit the limit on pipe buffers, fall back to unoptimized pump. + return unoptimizedPumpTo(input, *this, limit, readSoFar); + default: + KJ_FAIL_SYSCALL("pipe2()", error); + } + + AutoCloseFd pipeIn(pipeFds[0]), pipeOut(pipeFds[1]); + + return splicePumpLoop(input, pipeFds[0], pipeFds[1], readSoFar, limit, 0) + .attach(kj::mv(pipeIn), kj::mv(pipeOut)); + } + + Promise splicePumpLoop(AsyncStreamFd& input, int pipeIn, int pipeOut, + uint64_t readSoFar, uint64_t limit, size_t bufferedAmount) { + for (;;) { + while (bufferedAmount > 0) { + // First flush out whatever is in the pipe buffer. + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = splice(pipeIn, nullptr, fd, nullptr, + MAX_SPLICE_LEN, SPLICE_F_MOVE | SPLICE_F_NONBLOCK)); + if (n > 0) { + KJ_ASSERT(n <= bufferedAmount, "splice pipe larger than bufferedAmount?"); + bufferedAmount -= n; + } else { + KJ_ASSERT(n < 0, "splice pipe empty before bufferedAmount reached?", bufferedAmount); + return observer.whenBecomesWritable() + .then([this, &input, pipeIn, pipeOut, readSoFar, limit, bufferedAmount]() { + return splicePumpLoop(input, pipeIn, pipeOut, readSoFar, limit, bufferedAmount); + }); + } + } + + // Now the pipe buffer is empty, so we can try to read some more. + { + if (readSoFar >= limit) { + // Hit the limit, we're done. + KJ_ASSERT(readSoFar == limit); + return readSoFar; + } + + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = splice(input.fd, nullptr, pipeOut, nullptr, + kj::min(limit - readSoFar, MAX_SPLICE_LEN), SPLICE_F_MOVE | SPLICE_F_NONBLOCK)); + if (n == 0) { + // EOF. + return readSoFar; + } else if (n < 0) { + // No data available, wait. + return input.observer.whenBecomesReadable() + .then([this, &input, pipeIn, pipeOut, readSoFar, limit]() { + return splicePumpLoop(input, pipeIn, pipeOut, readSoFar, limit, 0); + }); + } + + readSoFar += n; + bufferedAmount = n; + } + } + } + +public: +#endif // __linux__ && !__ANDROID__ + Promise whenWriteDisconnected() override { KJ_IF_MAYBE(p, writeDisconnectedPromise) { return p->addBranch(); @@ -634,6 +880,10 @@ private: } }; +#if __linux__ && !__ANDROID__ +constexpr size_t AsyncStreamFd::MAX_SPLICE_LEN; +#endif // __linux__ && !__ANDROID__ + // ======================================================================================= class SocketAddress { @@ -949,58 +1199,6 @@ private: } addr; struct LookupParams; - class LookupReader; -}; - -class SocketAddress::LookupReader { - // Reads SocketAddresses off of a pipe coming from another thread that is performing - // getaddrinfo. - -public: - LookupReader(kj::Own&& thread, kj::Own&& input, - _::NetworkFilter& filter) - : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {} - - ~LookupReader() { - if (thread) thread->detach(); - } - - Promise> read() { - return input->tryRead(¤t, sizeof(current), sizeof(current)).then( - [this](size_t n) -> Promise> { - if (n < sizeof(current)) { - thread = nullptr; - // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check - // anyway. - KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; } - return addresses.releaseAsArray(); - } else { - // getaddrinfo() can return multiple copies of the same address for several reasons. - // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so - // it may return two copies of the same address, one for each type, unless it explicitly - // knows that the service name given is specific to one type. But we can't tell it a type, - // because we don't actually know which one the user wants, and if we specify SOCK_STREAM - // while the user specified a UDP service name then they'll get a resolution error which - // is lame. (At least, I think that's how it works.) - // - // So we instead resort to de-duping results. - if (alreadySeen.insert(current).second) { - if (current.parseAllowedBy(filter)) { - addresses.add(current); - } - } - return read(); - } - }); - } - -private: - kj::Own thread; - kj::Own input; - _::NetworkFilter& filter; - SocketAddress current; - kj::Vector addresses; - std::set alreadySeen; }; struct SocketAddress::LookupParams { @@ -1018,84 +1216,97 @@ Promise> SocketAddress::lookupHost( // Maybe use the various platform-specific asynchronous DNS libraries? Please do not implement // a custom DNS resolver... - int fds[2]; -#if __linux__ && !__BIONIC__ - KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC)); -#else - KJ_SYSCALL(pipe(fds)); -#endif - - auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS); - - int outFd = fds[1]; - + auto paf = newPromiseAndCrossThreadFulfiller>(); LookupParams params = { kj::mv(host), kj::mv(service) }; - auto thread = heap(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) { - FdOutputStream output((AutoCloseFd(outFd))); - - struct addrinfo* list; - int status = getaddrinfo( - params.host == "*" ? nullptr : params.host.cStr(), - params.service == nullptr ? nullptr : params.service.cStr(), - nullptr, &list); - if (status == 0) { - KJ_DEFER(freeaddrinfo(list)); - - struct addrinfo* cur = list; - while (cur != nullptr) { - if (params.service == nullptr) { - switch (cur->ai_addr->sa_family) { - case AF_INET: - ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); - break; - case AF_INET6: - ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); - break; - default: - break; + auto thread = heap( + [fulfiller=kj::mv(paf.fulfiller),params=kj::mv(params),portHint]() mutable { + // getaddrinfo() can return multiple copies of the same address for several reasons. + // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so + // it may return two copies of the same address, one for each type, unless it explicitly + // knows that the service name given is specific to one type. But we can't tell it a type, + // because we don't actually know which one the user wants, and if we specify SOCK_STREAM + // while the user specified a UDP service name then they'll get a resolution error which + // is lame. (At least, I think that's how it works.) + // + // So we instead resort to de-duping results. + std::set result; + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; +#if __BIONIC__ + // AI_V4MAPPED causes getaddrinfo() to fail on Bionic libc (Android). + hints.ai_flags = AI_ADDRCONFIG; +#else + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; +#endif + struct addrinfo* list; + int status = getaddrinfo( + params.host == "*" ? nullptr : params.host.cStr(), + params.service == nullptr ? nullptr : params.service.cStr(), + &hints, &list); + if (status == 0) { + KJ_DEFER(freeaddrinfo(list)); + + struct addrinfo* cur = list; + while (cur != nullptr) { + if (params.service == nullptr) { + switch (cur->ai_addr->sa_family) { + case AF_INET: + ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); + break; + case AF_INET6: + ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); + break; + default: + break; + } } - } - SocketAddress addr; - if (params.host == "*") { - // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). - addr.wildcard = true; - addr.addrlen = sizeof(addr.addr.inet6); - addr.addr.inet6.sin6_family = AF_INET6; - switch (cur->ai_addr->sa_family) { - case AF_INET: - addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; - break; - case AF_INET6: - addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; - break; - default: - addr.addr.inet6.sin6_port = portHint; - break; + SocketAddress addr; + if (params.host == "*") { + // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). + addr.wildcard = true; + addr.addrlen = sizeof(addr.addr.inet6); + addr.addr.inet6.sin6_family = AF_INET6; + switch (cur->ai_addr->sa_family) { + case AF_INET: + addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; + break; + case AF_INET6: + addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; + break; + default: + addr.addr.inet6.sin6_port = portHint; + break; + } + } else { + addr.addrlen = cur->ai_addrlen; + memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); } - } else { - addr.addrlen = cur->ai_addrlen; - memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); + result.insert(addr); + cur = cur->ai_next; + } + } else if (status == EAI_SYSTEM) { + KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) { + return; + } + } else { + KJ_FAIL_REQUIRE("DNS lookup failed.", + params.host, params.service, gai_strerror(status)) { + return; } - KJ_ASSERT_CAN_MEMCPY(SocketAddress); - output.write(&addr, sizeof(addr)); - cur = cur->ai_next; - } - } else if (status == EAI_SYSTEM) { - KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) { - return; } + })) { + fulfiller->reject(kj::mv(*exception)); } else { - KJ_FAIL_REQUIRE("DNS lookup failed.", - params.host, params.service, gai_strerror(status)) { - return; - } + fulfiller->fulfill(KJ_MAP(addr, result) { return addr; }); } - })); + }); - auto reader = heap(kj::mv(thread), kj::mv(input), filter); - return reader->read().attach(kj::mv(reader)); + return kj::mv(paf.promise); } // ======================================================================================= @@ -1155,7 +1366,8 @@ public: } AuthenticatedStream result; - result.stream = heap(eventPort, ownFd.release(), NEW_FD_FLAGS); + result.stream = heap(eventPort, ownFd.release(), NEW_FD_FLAGS, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); if (authenticated) { result.peerIdentity = SocketAddress(reinterpret_cast(&addr), addrlen) .getIdentity(lowLevel, filter, *result.stream); @@ -1263,27 +1475,28 @@ public: class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: LowLevelAsyncIoProviderImpl() - : eventLoop(eventPort), waitScope(eventLoop) {} + : eventPort(), eventLoop(eventPort), waitScope(eventLoop) {} inline WaitScope& getWaitScope() { return waitScope; } Own wrapInputFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ); } Own wrapOutputFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_WRITE); } Own wrapSocketFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ_WRITE); } Own wrapUnixSocketFd(Fd fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ_WRITE); } Promise> wrapConnectingSocketFd( int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override { // It's important that we construct the AsyncStreamFd first, so that `flags` are honored, // especially setting nonblocking mode and taking ownership. - auto result = heap(eventPort, fd, flags); + auto result = heap(eventPort, fd, flags, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates // non-blocking using EINPROGRESS. @@ -1294,7 +1507,8 @@ public: // Fine. break; } else if (error != EINTR) { - KJ_FAIL_SYSCALL("connect()", error) { break; } + auto address = SocketAddress(addr, addrlen).toString(); + KJ_FAIL_SYSCALL("connect()", error, address) { break; } return Own(); } } else { @@ -1304,7 +1518,7 @@ public: } auto connected = result->waitConnected(); - return connected.then(kj::mvCapture(result, [fd](Own&& stream) { + return connected.then([fd,stream=kj::mv(result)]() mutable -> Own { int err; socklen_t errlen = sizeof(err); KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen)); @@ -1312,7 +1526,7 @@ public: KJ_FAIL_SYSCALL("connect()", err) { break; } } return kj::mv(stream); - })); + }); } Own wrapListenSocketFd( int fd, NetworkFilter& filter, uint flags = 0) override { @@ -1356,29 +1570,31 @@ public: } Own listen() override { - if (addrs.size() > 1) { - KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will " - "be used. If this is incorrect, specify the address numerically. This may be fixed " - "in the future.", addrs[0].toString()); - } + auto makeReceiver = [&](SocketAddress& addr) { + int fd = addr.socket(SOCK_STREAM); - int fd = addrs[0].socket(SOCK_STREAM); + { + KJ_ON_SCOPE_FAILURE(close(fd)); - { - KJ_ON_SCOPE_FAILURE(close(fd)); + // We always enable SO_REUSEADDR because having to take your server down for five minutes + // before it can restart really sucks. + int optval = 1; + KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))); - // We always enable SO_REUSEADDR because having to take your server down for five minutes - // before it can restart really sucks. - int optval = 1; - KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))); + addr.bind(fd); - addrs[0].bind(fd); + // TODO(someday): Let queue size be specified explicitly in string addresses. + KJ_SYSCALL(::listen(fd, SOMAXCONN)); + } - // TODO(someday): Let queue size be specified explicitly in string addresses. - KJ_SYSCALL(::listen(fd, SOMAXCONN)); - } + return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS); + }; - return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS); + if (addrs.size() == 1) { + return makeReceiver(addrs[0]); + } else { + return newAggregateConnectionReceiver(KJ_MAP(addr, addrs) { return makeReceiver(addr); }); + } } Own bindDatagramPort() override { @@ -1530,9 +1746,9 @@ public: : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {} Promise> parseAddress(StringPtr addr, uint portHint = 0) override { - return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) { + return evalNow([&]() { return SocketAddress::parse(lowLevel, addr, portHint, filter); - })).then([this](Array addresses) -> Own { + }).then([this](Array addresses) -> Own { return heap(lowLevel, filter, kj::mv(addresses)); }); } @@ -1814,13 +2030,12 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); - auto thread = heap(kj::mvCapture(startFunc, - [threadFd](Function&& startFunc) { + auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { LowLevelAsyncIoProviderImpl lowLevel; auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); startFunc(ioProvider, *stream, lowLevel.getWaitScope()); - })); + }); return { kj::mv(thread), kj::mv(pipe) }; } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-win32.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-win32.c++ index aaa65a20d39..30a9230f86d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io-win32.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io-win32.c++ @@ -23,7 +23,7 @@ // For Unix implementation, see async-io-unix.c++. // Request Vista-level APIs. -#include "win32-api-version.h" +#include #include "async-io.h" #include "async-io-internal.h" @@ -722,58 +722,6 @@ private: } addr; struct LookupParams; - class LookupReader; -}; - -class SocketAddress::LookupReader { - // Reads SocketAddresses off of a pipe coming from another thread that is performing - // getaddrinfo. - -public: - LookupReader(kj::Own&& thread, kj::Own&& input, - _::NetworkFilter& filter) - : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {} - - ~LookupReader() { - if (thread) thread->detach(); - } - - Promise> read() { - return input->tryRead(¤t, sizeof(current), sizeof(current)).then( - [this](size_t n) -> Promise> { - if (n < sizeof(current)) { - thread = nullptr; - // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check - // anyway. - KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; } - return addresses.releaseAsArray(); - } else { - // getaddrinfo() can return multiple copies of the same address for several reasons. - // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so - // it may return two copies of the same address, one for each type, unless it explicitly - // knows that the service name given is specific to one type. But we can't tell it a type, - // because we don't actually know which one the user wants, and if we specify SOCK_STREAM - // while the user specified a UDP service name then they'll get a resolution error which - // is lame. (At least, I think that's how it works.) - // - // So we instead resort to de-duping results. - if (alreadySeen.insert(current).second) { - if (current.parseAllowedBy(filter)) { - addresses.add(current); - } - } - return read(); - } - }); - } - -private: - kj::Own thread; - kj::Own input; - _::NetworkFilter& filter; - SocketAddress current; - kj::Vector addresses; - std::set alreadySeen; }; struct SocketAddress::LookupParams { @@ -796,85 +744,84 @@ Promise> SocketAddress::lookupHost( // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the // docs. Never mind that DNS itself is ASCII... - SOCKET fds[2]; - KJ_WINSOCK(_::win32Socketpair(fds)); - - auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS); - - int outFd = fds[1]; - + auto paf = newPromiseAndCrossThreadFulfiller>(); LookupParams params = { kj::mv(host), kj::mv(service) }; - auto thread = heap(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) { - KJ_DEFER(closesocket(outFd)); - - struct addrinfo* list; - int status = getaddrinfo( - params.host == "*" ? nullptr : params.host.cStr(), - params.service == nullptr ? nullptr : params.service.cStr(), - nullptr, &list); - if (status == 0) { - KJ_DEFER(freeaddrinfo(list)); - - struct addrinfo* cur = list; - while (cur != nullptr) { - if (params.service == nullptr) { - switch (cur->ai_addr->sa_family) { - case AF_INET: - ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); - break; - case AF_INET6: - ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); - break; - default: - break; + auto thread = heap( + [fulfiller=kj::mv(paf.fulfiller),params=kj::mv(params),portHint]() mutable { + // getaddrinfo() can return multiple copies of the same address for several reasons. + // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so + // it may return two copies of the same address, one for each type, unless it explicitly + // knows that the service name given is specific to one type. But we can't tell it a type, + // because we don't actually know which one the user wants, and if we specify SOCK_STREAM + // while the user specified a UDP service name then they'll get a resolution error which + // is lame. (At least, I think that's how it works.) + // + // So we instead resort to de-duping results. + std::set result; + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + addrinfo* list; + int status = getaddrinfo( + params.host == "*" ? nullptr : params.host.cStr(), + params.service == nullptr ? nullptr : params.service.cStr(), + nullptr, &list); + if (status == 0) { + KJ_DEFER(freeaddrinfo(list)); + + addrinfo* cur = list; + while (cur != nullptr) { + if (params.service == nullptr) { + switch (cur->ai_addr->sa_family) { + case AF_INET: + ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); + break; + case AF_INET6: + ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); + break; + default: + break; + } } - } - SocketAddress addr; - memset(&addr, 0, sizeof(addr)); // mollify valgrind - if (params.host == "*") { - // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). - addr.wildcard = true; - addr.addrlen = sizeof(addr.addr.inet6); - addr.addr.inet6.sin6_family = AF_INET6; - switch (cur->ai_addr->sa_family) { - case AF_INET: - addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; - break; - case AF_INET6: - addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; - break; - default: - addr.addr.inet6.sin6_port = portHint; - break; + SocketAddress addr; + memset(&addr, 0, sizeof(addr)); // mollify valgrind + if (params.host == "*") { + // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). + addr.wildcard = true; + addr.addrlen = sizeof(addr.addr.inet6); + addr.addr.inet6.sin6_family = AF_INET6; + switch (cur->ai_addr->sa_family) { + case AF_INET: + addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; + break; + case AF_INET6: + addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; + break; + default: + addr.addr.inet6.sin6_port = portHint; + break; + } + } else { + addr.addrlen = cur->ai_addrlen; + memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); } - } else { - addr.addrlen = cur->ai_addrlen; - memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); + result.insert(addr); + cur = cur->ai_next; } - KJ_ASSERT_CAN_MEMCPY(SocketAddress); - - const char* data = reinterpret_cast(&addr); - size_t size = sizeof(addr); - while (size > 0) { - int n; - KJ_WINSOCK(n = send(outFd, data, size, 0)); - data += n; - size -= n; + } else { + KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) { + return; } - - cur = cur->ai_next; } + })) { + fulfiller->reject(kj::mv(*exception)); } else { - KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) { - return; - } + fulfiller->fulfill(KJ_MAP(addr, result) { return addr; }); } - })); + }); - auto reader = heap(kj::mv(thread), kj::mv(input), filter); - return reader->read().attach(kj::mv(reader)); + return kj::mv(paf.promise); } // ======================================================================================= @@ -914,9 +861,9 @@ public: } } - return op->onComplete().then(mvCapture(result, mvCapture(scratch, - [this,newFd] - (Array scratch, Own stream, Win32EventPort::IoResult ioResult) + return op->onComplete().then( + [this,newFd,stream=kj::mv(result),scratch=kj::mv(scratch)] + (Win32EventPort::IoResult ioResult) mutable -> Promise> { if (ioResult.errorCode != ERROR_SUCCESS) { KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; } @@ -933,11 +880,11 @@ public: // getpeername() to get the address. auto addr = SocketAddress::getPeerAddress(newFd); if (addr.allowedBy(filter)) { - return kj::mv(stream); + return Own(kj::mv(stream)); } else { return accept(); } - }))); + }); } uint getPort() override { @@ -994,9 +941,9 @@ public: SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd); auto connected = result->connect(addr, addrlen); - return connected.then(kj::mvCapture(result, [](Own&& result) { - return kj::mv(result); - })); + return connected.then([result=kj::mv(result)]() mutable -> Own { + return Own(kj::mv(result)); + }); } Own wrapListenSocketFd( SOCKET fd, NetworkFilter& filter, uint flags = 0) override { @@ -1139,9 +1086,9 @@ public: : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {} Promise> parseAddress(StringPtr addr, uint portHint = 0) override { - return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) { + return evalNow([&]() { return SocketAddress::parse(lowLevel, addr, portHint, filter); - })).then([this](Array addresses) -> Own { + }).then([this](Array addresses) -> Own { return heap(lowLevel, filter, kj::mv(addresses)); }); } @@ -1203,13 +1150,12 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); - auto thread = heap(kj::mvCapture(startFunc, - [threadFd](Function&& startFunc) { + auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { LowLevelAsyncIoProviderImpl lowLevel; auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); startFunc(ioProvider, *stream, lowLevel.getWaitScope()); - })); + }); return { kj::mv(thread), kj::mv(pipe) }; } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io.c++ index ec9fc9d71fb..5fea50d3862 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io.c++ @@ -21,7 +21,7 @@ #if _WIN32 // Request Vista-level APIs. -#include "win32-api-version.h" +#include #endif #include "async-io.h" @@ -31,12 +31,13 @@ #include "io.h" #include "one-of.h" #include +#include #if _WIN32 #include #include #include -#include "windows-sanity.h" +#include #define inet_pton InetPtonA #define inet_ntop InetNtopA #include @@ -75,12 +76,16 @@ void AsyncInputStream::registerAncillaryMessageHandler( KJ_UNIMPLEMENTED("registerAncillaryMsgHandler is not implemented by this AsyncInputStream"); } +Maybe> AsyncInputStream::tryTee(uint64_t) { + return nullptr; +} + namespace { class AsyncPump { public: - AsyncPump(AsyncInputStream& input, AsyncOutputStream& output, uint64_t limit) - : input(input), output(output), limit(limit) {} + AsyncPump(AsyncInputStream& input, AsyncOutputStream& output, uint64_t limit, uint64_t doneSoFar) + : input(input), output(output), limit(limit), doneSoFar(doneSoFar) {} Promise pump() { // TODO(perf): This could be more efficient by reading half a buffer at a time and then @@ -104,12 +109,20 @@ private: AsyncInputStream& input; AsyncOutputStream& output; uint64_t limit; - uint64_t doneSoFar = 0; + uint64_t doneSoFar; byte buffer[4096]; }; } // namespace +Promise unoptimizedPumpTo( + AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, + uint64_t completedSoFar) { + auto pump = heap(input, output, amount, completedSoFar); + auto promise = pump->pump(); + return promise.attach(kj::mv(pump)); +} + Promise AsyncInputStream::pumpTo( AsyncOutputStream& output, uint64_t amount) { // See if output wants to dispatch on us. @@ -118,9 +131,7 @@ Promise AsyncInputStream::pumpTo( } // OK, fall back to naive approach. - auto pump = heap(*this, output, amount); - auto promise = pump->pump(); - return promise.attach(kj::mv(pump)); + return unoptimizedPumpTo(*this, output, amount); } namespace { @@ -210,7 +221,7 @@ public: Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { if (minBytes == 0) { - return size_t(0); + return constPromise(); } else KJ_IF_MAYBE(s, state) { return s->tryRead(buffer, minBytes, maxBytes); } else { @@ -249,7 +260,7 @@ public: Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { if (amount == 0) { - return uint64_t(0); + return constPromise(); } else KJ_IF_MAYBE(s, state) { return s->pumpTo(output, amount); } else { @@ -337,7 +348,7 @@ public: Maybe> tryPumpFrom( AsyncInputStream& input, uint64_t amount) override { if (amount == 0) { - return Promise(uint64_t(0)); + return constPromise(); } else KJ_IF_MAYBE(s, state) { return s->tryPumpFrom(input, amount); } else { @@ -1405,7 +1416,7 @@ private: if (input.tryGetLength().orDefault(1) == 0) { // Yeah a pump would pump nothing. - return Promise(uint64_t(0)); + return constPromise(); } else { // While we *could* just return nullptr here, it would probably then fall back to a normal // buffered pump, which would allocate a big old buffer just to find there's nothing to @@ -1438,7 +1449,7 @@ private: public: Promise tryRead(void* readBufferPtr, size_t minBytes, size_t maxBytes) override { - return size_t(0); + return constPromise(); } Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, AutoCloseFd* fdBuffer, size_t maxFds) override { @@ -1450,7 +1461,7 @@ private: return ReadResult { 0, 0 }; } Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - return uint64_t(0); + return constPromise(); } void abortRead() override { // ignore @@ -1615,7 +1626,7 @@ public: } Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - if (limit == 0) return size_t(0); + if (limit == 0) return constPromise(); return inner->tryRead(buffer, kj::min(minBytes, limit), kj::min(maxBytes, limit)) .then([this,minBytes](size_t actual) { decreaseLimit(actual, minBytes); @@ -1624,7 +1635,7 @@ public: } Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - if (limit == 0) return uint64_t(0); + if (limit == 0) return constPromise(); auto requested = kj::min(amount, limit); return inner->pumpTo(output, requested) .then([this,requested](uint64_t actual) { @@ -1680,51 +1691,131 @@ CapabilityPipe newCapabilityPipe() { namespace { class AsyncTee final: public Refcounted { + class Buffer { + public: + Buffer() = default; + + uint64_t consume(ArrayPtr& readBuffer, size_t& minBytes); + // Consume as many bytes as possible, copying them into `readBuffer`. Return the number of bytes + // consumed. + // + // `readBuffer` and `minBytes` are both assigned appropriate new values, such that after any + // call to `consume()`, `readBuffer` will point to the remaining slice of unwritten space, and + // `minBytes` will have been decremented (clamped to zero) by the amount of bytes read. That is, + // the read can be considered fulfilled if `minBytes` is zero after a call to `consume()`. + + Array> asArray(uint64_t minBytes, uint64_t& amount); + // Consume the first `minBytes` of the buffer (or the entire buffer) and return it in an Array + // of ArrayPtrs, suitable for passing to AsyncOutputStream.write(). The outer Array + // owns the underlying data. + + void produce(Array bytes); + // Enqueue a byte array to the end of the buffer list. + + bool empty() const; + uint64_t size() const; + + Buffer clone() const { + size_t size = 0; + for (const auto& buf: bufferList) { + size += buf.size(); + } + auto builder = heapArrayBuilder(size); + for (const auto& buf: bufferList) { + builder.addAll(buf); + } + std::deque> deque; + deque.emplace_back(builder.finish()); + return Buffer{mv(deque)}; + } + + private: + Buffer(std::deque>&& buffer) : bufferList(mv(buffer)) {} + + std::deque> bufferList; + }; + + class Sink; + public: - using BranchId = uint; + class Branch final: public AsyncInputStream { + public: + Branch(Own teeArg): tee(mv(teeArg)) { + tee->branches.add(*this); + } - explicit AsyncTee(Own inner, uint64_t bufferSizeLimit) - : inner(mv(inner)), bufferSizeLimit(bufferSizeLimit), length(this->inner->tryGetLength()) {} - ~AsyncTee() noexcept(false) { - bool hasBranches = false; - for (auto& branch: branches) { - hasBranches = hasBranches || branch != nullptr; + Branch(Own teeArg, Branch& cloneFrom) + : tee(mv(teeArg)), buffer(cloneFrom.buffer.clone()) { + tee->branches.add(*this); } - KJ_ASSERT(!hasBranches, "destroying AsyncTee with branch still alive") { - // Don't std::terminate(). - break; + + ~Branch() noexcept(false) { + KJ_ASSERT(link.isLinked()) { + // Don't std::terminate(). + return; + } + tee->branches.remove(*this); + + KJ_REQUIRE(sink == nullptr, + "destroying tee branch with operation still in-progress; probably going to segfault") { + // Don't std::terminate(). + break; + } } - } - void addBranch(BranchId branch) { - KJ_REQUIRE(branches[branch] == nullptr, "branch already exists"); - branches[branch] = Branch(); - } + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return tee->tryRead(*this, buffer, minBytes, maxBytes); + } - void removeBranch(BranchId branch) { - auto& state = KJ_REQUIRE_NONNULL(branches[branch], "branch was already destroyed"); - KJ_REQUIRE(state.sink == nullptr, - "destroying tee branch with operation still in-progress; probably going to segfault") { + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return tee->pumpTo(*this, output, amount); + } + + Maybe tryGetLength() override { + return tee->tryGetLength(*this); + } + + Maybe> tryTee(uint64_t limit) override { + if (tee->getBufferSizeLimit() != limit) { + // Cannot optimize this path as the limit has changed, so we need a new AsyncTee to manage + // the limit. + return nullptr; + } + + return kj::heap(addRef(*tee), *this); + } + + private: + Own tee; + ListLink link; + + Buffer buffer; + Maybe sink; + + friend class AsyncTee; + }; + + explicit AsyncTee(Own inner, uint64_t bufferSizeLimit) + : inner(mv(inner)), bufferSizeLimit(bufferSizeLimit), length(this->inner->tryGetLength()) {} + ~AsyncTee() noexcept(false) { + KJ_ASSERT(branches.size() == 0, "destroying AsyncTee with branch still alive") { // Don't std::terminate(). break; } - - branches[branch] = nullptr; } - Promise tryRead(BranchId branch, void* buffer, size_t minBytes, size_t maxBytes) { - auto& state = KJ_ASSERT_NONNULL(branches[branch]); - KJ_ASSERT(state.sink == nullptr); + Promise tryRead(Branch& branch, void* buffer, size_t minBytes, size_t maxBytes) { + KJ_ASSERT(branch.sink == nullptr); // If there is excess data in the buffer for us, slurp that up. auto readBuffer = arrayPtr(reinterpret_cast(buffer), maxBytes); - auto readSoFar = state.buffer.consume(readBuffer, minBytes); + auto readSoFar = branch.buffer.consume(readBuffer, minBytes); if (minBytes == 0) { return readSoFar; } - if (state.buffer.empty()) { + if (branch.buffer.empty()) { KJ_IF_MAYBE(reason, stoppage) { // Prefer a short read to an exception. The exception prevents the pull loop from adding any // data to the buffer, so `readSoFar` will be zero the next time someone calls `tryRead()`, @@ -1736,37 +1827,39 @@ public: } } - auto promise = newAdaptedPromise(state.sink, readBuffer, minBytes, readSoFar); + auto promise = newAdaptedPromise( + branch.sink, readBuffer, minBytes, readSoFar); ensurePulling(); return mv(promise); } - Maybe tryGetLength(BranchId branch) { - auto& state = KJ_ASSERT_NONNULL(branches[branch]); - - return length.map([&state](uint64_t amount) { - return amount + state.buffer.size(); + Maybe tryGetLength(Branch& branch) { + return length.map([&branch](uint64_t amount) { + return amount + branch.buffer.size(); }); } - Promise pumpTo(BranchId branch, AsyncOutputStream& output, uint64_t amount) { - auto& state = KJ_ASSERT_NONNULL(branches[branch]); - KJ_ASSERT(state.sink == nullptr); + uint64_t getBufferSizeLimit() const { + return bufferSizeLimit; + } + + Promise pumpTo(Branch& branch, AsyncOutputStream& output, uint64_t amount) { + KJ_ASSERT(branch.sink == nullptr); if (amount == 0) { return amount; } - if (state.buffer.empty()) { + if (branch.buffer.empty()) { KJ_IF_MAYBE(reason, stoppage) { if (reason->is()) { - return uint64_t(0); + return constPromise(); } return cp(reason->get()); } } - auto promise = newAdaptedPromise(state.sink, output, amount); + auto promise = newAdaptedPromise(branch.sink, output, amount); ensurePulling(); return mv(promise); } @@ -1775,32 +1868,6 @@ private: struct Eof {}; using Stoppage = OneOf; - class Buffer { - public: - uint64_t consume(ArrayPtr& readBuffer, size_t& minBytes); - // Consume as many bytes as possible, copying them into `readBuffer`. Return the number of bytes - // consumed. - // - // `readBuffer` and `minBytes` are both assigned appropriate new values, such that after any - // call to `consume()`, `readBuffer` will point to the remaining slice of unwritten space, and - // `minBytes` will have been decremented (clamped to zero) by the amount of bytes read. That is, - // the read can be considered fulfilled if `minBytes` is zero after a call to `consume()`. - - Array> asArray(uint64_t minBytes, uint64_t& amount); - // Consume the first `minBytes` of the buffer (or the entire buffer) and return it in an Array - // of ArrayPtrs, suitable for passing to AsyncOutputStream.write(). The outer Array - // owns the underlying data. - - void produce(Array bytes); - // Enqueue a byte array to the end of the buffer list. - - bool empty() const; - uint64_t size() const; - - private: - std::deque> bufferList; - }; - class Sink { public: struct Need { @@ -1843,7 +1910,7 @@ private: KJ_ASSERT(sinkLink == nullptr, "sink initiated with sink already in flight"); sinkLink = *this; } - KJ_DISALLOW_COPY(SinkBase); + KJ_DISALLOW_COPY_AND_MOVE(SinkBase); ~SinkBase() noexcept(false) { detach(); } void reject(Exception&& exception) override { @@ -1875,11 +1942,6 @@ private: Maybe& sinkLink; }; - struct Branch { - Buffer buffer; - Maybe sink; - }; - class ReadSink final: public SinkBase { public: explicit ReadSink(PromiseFulfiller& fulfiller, Maybe& registration, @@ -1999,18 +2061,14 @@ private: uint64_t minBytes = 0; uint64_t maxBytes = kj::maxValue; - uint nBranches = 0; uint nSinks = 0; - for (auto& state: branches) { - KJ_IF_MAYBE(s, state) { - ++nBranches; - KJ_IF_MAYBE(sink, s->sink) { - ++nSinks; - auto need = sink->need(); - minBytes = kj::max(minBytes, need.minBytes); - maxBytes = kj::min(maxBytes, need.maxBytes); - } + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + ++nSinks; + auto need = sink->need(); + minBytes = kj::max(minBytes, need.minBytes); + maxBytes = kj::min(maxBytes, need.maxBytes); } } @@ -2041,11 +2099,9 @@ private: return pullLoop().eagerlyEvaluate([this](Exception&& exception) { // Exception from our loop, not from inner tryRead(). Something is broken; tell everybody! pulling = false; - for (auto& state: branches) { - KJ_IF_MAYBE(s, state) { - KJ_IF_MAYBE(sink, s->sink) { - sink->reject(KJ_EXCEPTION(FAILED, "Exception in tee loop", exception)); - } + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + sink->reject(KJ_EXCEPTION(FAILED, "Exception in tee loop", exception)); } } }); @@ -2056,7 +2112,7 @@ private: Own inner; const uint64_t bufferSizeLimit = kj::maxValue; Maybe length; - Maybe branches[2]; + List branches; Maybe stoppage; Promise pullPromise = READY_NOW; bool pulling = false; @@ -2070,11 +2126,9 @@ private: Vector> promises; - for (auto& state: branches) { - KJ_IF_MAYBE(s, state) { - KJ_IF_MAYBE(sink, s->sink) { - promises.add(sink->fill(s->buffer, stoppage)); - } + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + promises.add(sink->fill(branch.buffer, stoppage)); } } @@ -2108,13 +2162,11 @@ private: n.maxBytes = kj::min(n.maxBytes, MAX_BLOCK_SIZE); n.maxBytes = kj::min(n.maxBytes, bufferSizeLimit); n.maxBytes = kj::max(n.minBytes, n.maxBytes); - for (auto& state: branches) { - KJ_IF_MAYBE(s, state) { - // TODO(perf): buffer.size() is O(n) where n = # of individual heap-allocated byte arrays. - if (s->buffer.size() + n.maxBytes > bufferSizeLimit) { - stoppage = Stoppage(KJ_EXCEPTION(FAILED, "tee buffer size limit exceeded")); - return pullLoop(); - } + for (auto& branch: branches) { + // TODO(perf): buffer.size() is O(n) where n = # of individual heap-allocated byte arrays. + if (branch.buffer.size() + n.maxBytes > bufferSizeLimit) { + stoppage = Stoppage(KJ_EXCEPTION(FAILED, "tee buffer size limit exceeded")); + return pullLoop(); } } auto heapBuffer = heapArray(n.maxBytes); @@ -2142,19 +2194,17 @@ private: KJ_ASSERT(stoppage == nullptr); Maybe> bufferPtr = nullptr; - for (auto& state: branches) { - KJ_IF_MAYBE(s, state) { - // Prefer to move the buffer into the receiving branch's deque, rather than memcpy. - // - // TODO(perf): For the 2-branch case, this is fine, since the majority of the time - // only one buffer will be in use. If we generalize to the n-branch case, this would - // become memcpy-heavy. - KJ_IF_MAYBE(ptr, bufferPtr) { - s->buffer.produce(heapArray(*ptr)); - } else { - bufferPtr = ArrayPtr(heapBuffer); - s->buffer.produce(mv(heapBuffer)); - } + for (auto& branch: branches) { + // Prefer to move the buffer into the receiving branch's deque, rather than memcpy. + // + // TODO(perf): For the 2-branch case, this is fine, since the majority of the time + // only one buffer will be in use. If we generalize to the n-branch case, this would + // become memcpy-heavy. + KJ_IF_MAYBE(ptr, bufferPtr) { + branch.buffer.produce(heapArray(*ptr)); + } else { + bufferPtr = ArrayPtr(heapBuffer); + branch.buffer.produce(mv(heapBuffer)); } } @@ -2254,41 +2304,16 @@ uint64_t AsyncTee::Buffer::size() const { return result; } -class TeeBranch final: public AsyncInputStream { -public: - TeeBranch(Own tee, uint8_t branch): tee(mv(tee)), branch(branch) { - this->tee->addBranch(branch); - } - ~TeeBranch() noexcept(false) { - unwind.catchExceptionsIfUnwinding([&]() { - tee->removeBranch(branch); - }); - } - - Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return tee->tryRead(branch, buffer, minBytes, maxBytes); - } - - Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - return tee->pumpTo(branch, output, amount); - } - - Maybe tryGetLength() override { - return tee->tryGetLength(branch); - } - -private: - Own tee; - const uint8_t branch; - UnwindDetector unwind; -}; - } // namespace Tee newTee(Own input, uint64_t limit) { + KJ_IF_MAYBE(t, input->tryTee(limit)) { + return { { mv(input), mv(*t) }}; + } + auto impl = refcounted(mv(input), limit); - Own branch1 = heap(addRef(*impl), 0); - Own branch2 = heap(mv(impl), 1); + Own branch1 = heap(addRef(*impl)); + Own branch2 = heap(mv(impl)); return { { mv(branch1), mv(branch2) } }; } @@ -2549,7 +2574,7 @@ kj::Promise>> AsyncCapabilityStream::tryReceive } KJ_REQUIRE(actual.capCount == 1, - "expected to receive a capability (e.g. file descirptor via SCM_RIGHTS), but didn't") { + "expected to receive a capability (e.g. file descriptor via SCM_RIGHTS), but didn't") { return nullptr; } @@ -2724,9 +2749,9 @@ Promise> CapabilityStreamNetworkAddress::connect() { } auto result = kj::mv(pipe.ends[0]); return inner.sendStream(kj::mv(pipe.ends[1])) - .then(kj::mvCapture(result, [](Own&& result) { - return kj::mv(result); - })); + .then([result=kj::mv(result)]() mutable { + return Own(kj::mv(result)); + }); } Promise CapabilityStreamNetworkAddress::connectAuthenticated() { return connect().then([](Own&& stream) { @@ -2744,160 +2769,197 @@ String CapabilityStreamNetworkAddress::toString() { return kj::str(""); } -// ======================================================================================= - -namespace _ { // private - -#if !_WIN32 +Promise FileInputStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) { + // Note that our contract with `minBytes` is that we should only return fewer than `minBytes` on + // EOF. A file read will only produce fewer than the requested number of bytes if EOF was reached. + // `minBytes` cannot be greater than `maxBytes`. So, this read satisfies the `minBytes` + // requirement. + size_t result = file.read(offset, arrayPtr(reinterpret_cast(buffer), maxBytes)); + offset += result; + return result; +} -kj::ArrayPtr safeUnixPath(const struct sockaddr_un* addr, uint addrlen) { - KJ_REQUIRE(addr->sun_family == AF_UNIX, "not a unix address"); - KJ_REQUIRE(addrlen >= offsetof(sockaddr_un, sun_path), "invalid unix address"); +Maybe FileInputStream::tryGetLength() { + uint64_t size = file.stat().size; + return offset < size ? size - offset : 0; +} - size_t maxPathlen = addrlen - offsetof(sockaddr_un, sun_path); +Promise FileOutputStream::write(const void* buffer, size_t size) { + file.write(offset, arrayPtr(reinterpret_cast(buffer), size)); + offset += size; + return kj::READY_NOW; +} - size_t pathlen; - if (maxPathlen > 0 && addr->sun_path[0] == '\0') { - // Linux "abstract" unix address - pathlen = strnlen(addr->sun_path + 1, maxPathlen - 1) + 1; - } else { - pathlen = strnlen(addr->sun_path, maxPathlen); +Promise FileOutputStream::write(ArrayPtr> pieces) { + // TODO(perf): Extend kj::File with an array-of-arrays write? + for (auto piece: pieces) { + file.write(offset, piece); + offset += piece.size(); } - return kj::arrayPtr(addr->sun_path, pathlen); + return kj::READY_NOW; } -#endif // !_WIN32 +Promise FileOutputStream::whenWriteDisconnected() { + return kj::NEVER_DONE; +} -CidrRange::CidrRange(StringPtr pattern) { - size_t slashPos = KJ_REQUIRE_NONNULL(pattern.findFirst('/'), "invalid CIDR", pattern); +// ======================================================================================= - bitCount = pattern.slice(slashPos + 1).parseAs(); +namespace { - KJ_STACK_ARRAY(char, addr, slashPos + 1, 128, 128); - memcpy(addr.begin(), pattern.begin(), slashPos); - addr[slashPos] = '\0'; +class AggregateConnectionReceiver final: public ConnectionReceiver { +public: + AggregateConnectionReceiver(Array> receiversParam) + : receivers(kj::mv(receiversParam)), + acceptTasks(kj::heapArray>>(receivers.size())) {} - if (pattern.findFirst(':') == nullptr) { - family = AF_INET; - KJ_REQUIRE(bitCount <= 32, "invalid CIDR", pattern); - } else { - family = AF_INET6; - KJ_REQUIRE(bitCount <= 128, "invalid CIDR", pattern); + Promise> accept() override { + return acceptAuthenticated().then([](AuthenticatedStream&& authenticated) { + return kj::mv(authenticated.stream); + }); } - KJ_ASSERT(inet_pton(family, addr.begin(), bits) > 0, "invalid CIDR", pattern); - zeroIrrelevantBits(); -} - -CidrRange::CidrRange(int family, ArrayPtr bits, uint bitCount) - : family(family), bitCount(bitCount) { - if (family == AF_INET) { - KJ_REQUIRE(bitCount <= 32); - } else { - KJ_REQUIRE(bitCount <= 128); + Promise acceptAuthenticated() override { + // Whenever our accept() is called, we want it to resolve to the first connection accepted by + // any of our child receivers. Naively, it may seem like we should call accept() on them all + // and exclusiveJoin() the results. Unfortunately, this might not work in a certain race + // condition: if two or more of our children receive connections simultaneously, both child + // accept() calls may return, but we'll only end up taking one and dropping the other. + // + // To avoid this problem, we must instead initiate `accept()` calls on all children, and even + // after one of them returns a result, we must allow the others to keep running. If we end up + // accepting any sockets from children when there is no outstanding accept() on the aggregate, + // we must put that socket into a backlog. We only restart accept() calls on children if the + // backlog is empty, and hence the maximum length of the backlog is the number of children + // minus 1. + + if (backlog.empty()) { + auto result = kj::newAdaptedPromise(*this); + ensureAllAccepting(); + return result; + } else { + auto result = kj::mv(backlog.front()); + backlog.pop_front(); + return result; + } } - KJ_REQUIRE(bits.size() * 8 >= bitCount); - size_t byteCount = (bitCount + 7) / 8; - memcpy(this->bits, bits.begin(), byteCount); - memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount); - - zeroIrrelevantBits(); -} -CidrRange CidrRange::inet4(ArrayPtr bits, uint bitCount) { - return CidrRange(AF_INET, bits, bitCount); -} -CidrRange CidrRange::inet6( - ArrayPtr prefix, ArrayPtr suffix, - uint bitCount) { - KJ_REQUIRE(prefix.size() + suffix.size() <= 8); - - byte bits[16] = { 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, }; - - for (size_t i: kj::indices(prefix)) { - bits[i * 2] = prefix[i] >> 8; - bits[i * 2 + 1] = prefix[i] & 0xff; + uint getPort() override { + return receivers[0]->getPort(); } - - byte* suffixBits = bits + (16 - suffix.size() * 2); - for (size_t i: kj::indices(suffix)) { - suffixBits[i * 2] = suffix[i] >> 8; - suffixBits[i * 2 + 1] = suffix[i] & 0xff; + void getsockopt(int level, int option, void* value, uint* length) override { + return receivers[0]->getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, uint length) override { + // Apply to all. + for (auto& r: receivers) { + r->setsockopt(level, option, value, length); + } + } + void getsockname(struct sockaddr* addr, uint* length) override { + return receivers[0]->getsockname(addr, length); } - return CidrRange(AF_INET6, bits, bitCount); -} - -bool CidrRange::matches(const struct sockaddr* addr) const { - const byte* otherBits; - - switch (family) { - case AF_INET: - if (addr->sa_family == AF_INET6) { - otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; - static constexpr byte V6MAPPED[12] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff }; - if (memcmp(otherBits, V6MAPPED, sizeof(V6MAPPED)) == 0) { - // We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning - // it's equivalent to an ipv4 address. Try to match against the ipv4 part. - otherBits = otherBits + sizeof(V6MAPPED); - } else { - return false; - } - } else if (addr->sa_family == AF_INET) { - otherBits = reinterpret_cast( - &reinterpret_cast(addr)->sin_addr.s_addr); - } else { - return false; +private: + Array> receivers; + Array>> acceptTasks; + + struct Waiter { + Waiter(PromiseFulfiller& fulfiller, + AggregateConnectionReceiver& parent) + : fulfiller(fulfiller), parent(parent) { + parent.waiters.add(*this); + } + ~Waiter() noexcept(false) { + if (link.isLinked()) { + parent.waiters.remove(*this); } + } - break; + PromiseFulfiller& fulfiller; + AggregateConnectionReceiver& parent; + ListLink link; + }; - case AF_INET6: - if (addr->sa_family != AF_INET6) return false; + List waiters; + std::deque> backlog; + // At least one of `waiters` or `backlog` is always empty. - otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; - break; + void ensureAllAccepting() { + for (auto i: kj::indices(receivers)) { + if (acceptTasks[i] == nullptr) { + acceptTasks[i] = acceptLoop(i); + } + } + } - default: - KJ_UNREACHABLE; + Promise acceptLoop(size_t index) { + return kj::evalNow([&]() { return receivers[index]->acceptAuthenticated(); }) + .then([this](AuthenticatedStream&& as) { + if (waiters.empty()) { + backlog.push_back(kj::mv(as)); + } else { + auto& waiter = waiters.front(); + waiter.fulfiller.fulfill(kj::mv(as)); + waiters.remove(waiter); + } + }, [this](Exception&& e) { + if (waiters.empty()) { + backlog.push_back(kj::mv(e)); + } else { + auto& waiter = waiters.front(); + waiter.fulfiller.reject(kj::mv(e)); + waiters.remove(waiter); + } + }).then([this, index]() -> Promise { + if (waiters.empty()) { + // Don't keep accepting if there's no one waiting. + // HACK: We can't cancel ourselves, so detach the task so we can null out the slot. + // We know that the promise we're detaching here is exactly the promise that's currently + // executing and has no further `.then()`s on it, so no further callbacks will run in + // detached state... we're just using `detach()` as a tricky way to have the event loop + // dispose of this promise later after we've returned. + // TODO(cleanup): This pattern has come up several times, we need a better way to handle + // it. + KJ_ASSERT_NONNULL(acceptTasks[index]).detach([](auto&&) {}); + acceptTasks[index] = nullptr; + return READY_NOW; + } else { + return acceptLoop(index); + } + }); } +}; - if (memcmp(bits, otherBits, bitCount / 8) != 0) return false; +} // namespace - return bitCount == 128 || - bits[bitCount / 8] == (otherBits[bitCount / 8] & (0xff00 >> (bitCount % 8))); +Own newAggregateConnectionReceiver(Array> receivers) { + return kj::heap(kj::mv(receivers)); } -bool CidrRange::matchesFamily(int family) const { - switch (family) { - case AF_INET: - return this->family == AF_INET; - case AF_INET6: - // Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range. - return true; - default: - return false; - } -} +// ----------------------------------------------------------------------------- -String CidrRange::toString() const { - char result[128]; - KJ_ASSERT(inet_ntop(family, (void*)bits, result, sizeof(result)) == result); - return kj::str(result, '/', bitCount); -} +namespace _ { // private + +#if !_WIN32 + +kj::ArrayPtr safeUnixPath(const struct sockaddr_un* addr, uint addrlen) { + KJ_REQUIRE(addr->sun_family == AF_UNIX, "not a unix address"); + KJ_REQUIRE(addrlen >= offsetof(sockaddr_un, sun_path), "invalid unix address"); -void CidrRange::zeroIrrelevantBits() { - // Mask out insignificant bits of partial byte. - if (bitCount < 128) { - bits[bitCount / 8] &= 0xff00 >> (bitCount % 8); + size_t maxPathlen = addrlen - offsetof(sockaddr_un, sun_path); - // Zero the remaining bytes. - size_t n = bitCount / 8 + 1; - memset(bits + n, 0, sizeof(bits) - n); + size_t pathlen; + if (maxPathlen > 0 && addr->sun_path[0] == '\0') { + // Linux "abstract" unix address + pathlen = strnlen(addr->sun_path + 1, maxPathlen - 1) + 1; + } else { + pathlen = strnlen(addr->sun_path, maxPathlen); } + return kj::arrayPtr(addr->sun_path, pathlen); } -// ----------------------------------------------------------------------------- +#endif // !_WIN32 ArrayPtr localCidrs() { static const CidrRange result[] = { @@ -2962,6 +3024,13 @@ ArrayPtr exampleAddresses() { return kj::arrayPtr(result, kj::size(result)); } +bool matchesAny(ArrayPtr cidrs, const struct sockaddr* addr) { + for (auto& cidr: cidrs) { + if (cidr.matches(addr)) return true; + } + return false; +} + NetworkFilter::NetworkFilter() : allowUnix(true), allowAbstractUnix(true) { allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); @@ -2976,17 +3045,14 @@ NetworkFilter::NetworkFilter(ArrayPtr allow, ArrayPtrsa_family == AF_INET || addr->sa_family == AF_INET6) && + !matchesAny(privateCidrs(), addr) && !matchesAny(localCidrs(), addr)) { + allowed = true; + // Don't adjust allowSpecificity as this match has an effective specificity of zero. + } + } + + if (allowNetwork) { + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + !matchesAny(localCidrs(), addr)) { + allowed = true; + // Don't adjust allowSpecificity as this match has an effective specificity of zero. + } + } + for (auto& cidr: allowCidrs) { if (cidr.matches(addr)) { allowSpecificity = kj::max(allowSpecificity, cidr.getSpecificity()); @@ -3064,6 +3147,10 @@ bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) } } else { #endif + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + (allowPublic || allowNetwork)) { + matched = true; + } for (auto& cidr: allowCidrs) { if (cidr.matchesFamily(addr->sa_family)) { matched = true; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-io.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-io.h index 3be72377a91..deff884bd22 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-io.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-io.h @@ -22,9 +22,9 @@ #pragma once #include "async.h" -#include "function.h" -#include "thread.h" -#include "timer.h" +#include +#include +#include KJ_BEGIN_HEADER @@ -45,10 +45,13 @@ class AsyncOutputStream; class AsyncIoStream; class AncillaryMessage; +class ReadableFile; +class File; + // ======================================================================================= // Streaming I/O -class AsyncInputStream { +class AsyncInputStream: private AsyncObject { // Asynchronous equivalent of InputStream (from io.h). public: @@ -91,9 +94,18 @@ class AsyncInputStream { // The provided callback will be called whenever any are encountered. The messages passed to // the function do not live beyond when function returns. // Only supported on Unix (the default impl throws UNIMPLEMENTED). Most apps will not use this. + + virtual Maybe> tryTee(uint64_t limit = kj::maxValue); + // Primarily intended as an optimization for the `tee` call. Returns an input stream whose state + // is independent from this one but which will return the exact same set of bytes read going + // forward. limit is a total limit on the amount of memory, in bytes, which a tee implementation + // may use to buffer stream data. An implementation must throw an exception if a read operation + // would cause the limit to be exceeded. If tryTee() can see that the new limit is impossible to + // satisfy, it should return nullptr so that the pessimized path is taken in newTee. This is + // likely to arise if tryTee() is called twice with different limits on the same stream. }; -class AsyncOutputStream { +class AsyncOutputStream: private AsyncObject { // Asynchronous equivalent of OutputStream (from io.h). public: @@ -159,6 +171,20 @@ class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { // isn't wrapping a file descriptor. }; +Promise unoptimizedPumpTo( + AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, + uint64_t completedSoFar = 0); +// Performs a pump using read() and write(), without calling the stream's pumpTo() nor +// tryPumpFrom() methods. This is intended to be used as a fallback by implementations of pumpTo() +// and tryPumpFrom() when they want to give up on optimization, but can't just call pumpTo() again +// because this would recursively retry the optimization. unoptimizedPumpTo() should only be called +// inside implementations of streams, never by the caller of a stream -- use the pumpTo() method +// instead. +// +// `completedSoFar` is the number of bytes out of `amount` that have already been pumped. This is +// provided for convenience for cases where the caller has already done some pumping before they +// give up. Otherwise, a `.then()` would need to be used to add the bytes to the final result. + class AsyncCapabilityStream: public AsyncIoStream { // An AsyncIoStream that also allows transmitting new stream objects and file descriptors // (capabilities, in the object-capability model sense), in addition to bytes. @@ -405,7 +431,7 @@ class UnknownPeerIdentity: public PeerIdentity { // ======================================================================================= // Accepting connections -class ConnectionReceiver { +class ConnectionReceiver: private AsyncObject { // Represents a server socket listening on a port. public: @@ -430,6 +456,10 @@ class ConnectionReceiver { // Same as the methods of AsyncIoStream. }; +Own newAggregateConnectionReceiver(Array> receivers); +// Create a ConnectionReceiver that listens on several other ConnectionReceivers and returns +// sockets from any of them. + // ======================================================================================= // Datagram I/O @@ -536,7 +566,7 @@ class DatagramPort { // ======================================================================================= // Networks -class NetworkAddress { +class NetworkAddress: private AsyncObject { // Represents a remote address to which the application can connect. public: @@ -972,6 +1002,68 @@ class CapabilityStreamNetworkAddress final: public NetworkAddress { AsyncCapabilityStream& inner; }; +class FileInputStream: public AsyncInputStream { + // InputStream that reads from a disk file -- and enables sendfile() optimization. + // + // Reads are performed synchronously -- no actual attempt is made to use asynchronous file I/O. + // True asynchronous file I/O is complicated and is mostly unnecessary in the presence of + // caching. Only certain niche programs can expect to benefit from it. For the rest, it's better + // to use regular syrchronous disk I/O, so that's what this class does. + // + // The real purpose of this class, aside from general convenience, is to enable sendfile() + // optimization. When you use this class's pumpTo() method, and the destination is a socket, + // the system will detect this and optimize to sendfile(), so that the file data never needs to + // be read into userspace. + // + // NOTE: As of this writing, sendfile() optimization is only implemented on Linux. + +public: + FileInputStream(const ReadableFile& file, uint64_t offset = 0) + : file(file), offset(offset) {} + + const ReadableFile& getUnderlyingFile() { return file; } + uint64_t getOffset() { return offset; } + void seek(uint64_t newOffset) { offset = newOffset; } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes); + Maybe tryGetLength(); + + // (pumpTo() is not actually overridden here, but AsyncStreamFd's tryPumpFrom() will detect when + // the source is a file.) + +private: + const ReadableFile& file; + uint64_t offset; +}; + +class FileOutputStream: public AsyncOutputStream { + // OutputStream that writes to a disk file. + // + // As with FileInputStream, calls are not actually async. Async would be even less useful here + // because writes should usually land in cache anyway. + // + // sendfile() optimization does not apply when writing to a file, but on Linux, splice() can + // be used to achieve a similar effect. + // + // NOTE: As of this writing, splice() optimization is not implemented. + +public: + FileOutputStream(const File& file, uint64_t offset = 0) + : file(file), offset(offset) {} + + const File& getUnderlyingFile() { return file; } + uint64_t getOffset() { return offset; } + void seek(uint64_t newOffset) { offset = newOffset; } + + Promise write(const void* buffer, size_t size); + Promise write(ArrayPtr> pieces); + Promise whenWriteDisconnected(); + +private: + const File& file; + uint64_t offset; +}; + // ======================================================================================= // inline implementation details @@ -996,6 +1088,57 @@ inline ArrayPtr AncillaryMessage::asArray() const { return arrayPtr(reinterpret_cast(data.begin()), data.size() / sizeof(T)); } +class SecureNetworkWrapper { + // Abstract interface for a class which implements a "secure" network as a wrapper around an + // insecure one. "secure" means: + // * Connections to a server will only succeed if it can be verified that the requested hostname + // actually belongs to the responding server. + // * No man-in-the-middle attacker can potentially see the bytes sent and received. + // + // The typical implementation uses TLS. The object in this case could be configured to use cerain + // keys, certificates, etc. See kj/compat/tls.h for such an implementation. + // + // However, an implementation could use some other form of encryption, or might not need to use + // encryption at all. For example, imagine a kj::Network that exists only on a single machine, + // providing communications between various processes using unix sockets. Perhaps the "hostnames" + // are actually PIDs in this case. An implementation of such a network could verify the other + // side's identity using an `SCM_CREDENTIALS` auxiliary message, which cannot be forged. Once + // verified, there is no need to encrypt since unix sockets cannot be intercepted. + +public: + virtual kj::Promise> wrapServer(kj::Own stream) = 0; + // Act as the server side of a connection. The given stream is already connected to a client, but + // no authentication has occurred. The returned stream represents the secure transport once + // established. + + virtual kj::Promise> wrapClient( + kj::Own stream, kj::StringPtr expectedServerHostname) = 0; + // Act as the client side of a connection. The given stream is already connecetd to a server, but + // no authentication has occurred. This method will verify that the server actually is the given + // hostname, then return the stream representing a secure transport to that server. + + virtual kj::Promise wrapServer(kj::AuthenticatedStream stream) = 0; + virtual kj::Promise wrapClient( + kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) = 0; + // Same as above, but implementing kj::AuthenticatedStream, which provides PeerIdentity objects + // with more details about the peer. The SecureNetworkWrapper will provide its own implementation + // of PeerIdentity with the specific details it is able to authenticate. + + virtual kj::Own wrapPort(kj::Own port) = 0; + // Wrap a connection listener. This is equivalent to calling wrapServer() on every connection + // received. + + virtual kj::Own wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname) = 0; + // Wrap a NetworkAddress. This is equivalent to calling `wrapClient()` on every connection + // formed by calling `connect()` on the address. + + virtual kj::Own wrapNetwork(kj::Network& network) = 0; + // Wrap a whole `kj::Network`. This automatically wraps everything constructed using the network. + // The network will only accept address strings that can be authenticated, and will automatically + // authenticate servers against those addresses when connecting to them. +}; + } // namespace kj KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-prelude.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-prelude.h index c3fd4b19455..6289bf3fa05 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-prelude.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-prelude.h @@ -24,8 +24,28 @@ #pragma once -#include "exception.h" -#include "tuple.h" +#include +#include +#include + +// Detect whether or not we should enable kj::Promise coroutine integration. +// +// TODO(someday): Support coroutines with -fno-exceptions. +#if !KJ_NO_EXCEPTIONS +#ifdef __has_include +#if (__cpp_impl_coroutine >= 201902L) && __has_include() +// C++20 Coroutines detected. +#include +#define KJ_HAS_COROUTINE 1 +#define KJ_COROUTINE_STD_NAMESPACE std +#elif (__cpp_coroutines >= 201703L) && __has_include() +// Coroutines TS detected. +#include +#define KJ_HAS_COROUTINE 1 +#define KJ_COROUTINE_STD_NAMESPACE std::experimental +#endif +#endif +#endif KJ_BEGIN_HEADER @@ -37,9 +57,9 @@ class Promise; class WaitScope; class TaskSet; -template -Promise> joinPromises(Array>&& promises); -Promise joinPromises(Array>&& promises); +Promise joinPromises(Array>&& promises, SourceLocation location = {}); +Promise joinPromisesFailFast(Array>&& promises, SourceLocation location = {}); +// Out-of-line specialization of template function defined in async.h. namespace _ { // private @@ -195,16 +215,20 @@ class Event; class XThreadEvent; class XThreadPaf; +class PromiseDisposer; +using OwnPromiseNode = Own; +// PromiseNode uses a static disposer. + class PromiseBase { public: kj::String trace(); // Dump debug info about this promise. private: - Own node; + OwnPromiseNode node; PromiseBase() = default; - PromiseBase(Own&& node): node(kj::mv(node)) {} + PromiseBase(OwnPromiseNode&& node): node(kj::mv(node)) {} template friend class kj::Promise; @@ -212,18 +236,25 @@ class PromiseBase { }; void detach(kj::Promise&& promise); -void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope); -bool pollImpl(_::PromiseNode& node, WaitScope& waitScope); +void waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, WaitScope& waitScope, + SourceLocation location); +bool pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); Promise yield(); Promise yieldHarder(); -Own neverDone(); +OwnPromiseNode readyNow(); +OwnPromiseNode neverDone(); + +class ReadyNow { +public: + operator Promise() const; +}; class NeverDone { public: template operator Promise() const; - KJ_NORETURN(void wait(WaitScope& waitScope) const); + KJ_NORETURN(void wait(WaitScope& waitScope, SourceLocation location = {}) const); }; } // namespace _ (private) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-queue-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-queue-test.c++ index c82aadd6465..3d8c8dd8fdf 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-queue-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-queue-test.c++ @@ -82,23 +82,23 @@ KJ_TEST("ProducerConsumerQueue with various amounts of producers and consumers") producerCount, consumerCount, kItemCount); // Make a vector to track our entries. auto bits = Vector(kItemCount); - for (size_t i = 0; i < kItemCount; ++i) { + for (auto i KJ_UNUSED : kj::zeroTo(kItemCount)) { bits.add(false); } // Make enough producers. auto producers = Vector(); - for (size_t i = 0; i < producerCount; ++i) { + for (auto i KJ_UNUSED : kj::zeroTo(producerCount)) { producers.add(test); } // Make enough consumers. auto consumers = Vector(); - for (size_t i = 0; i < consumerCount; ++i) { + for (auto i KJ_UNUSED : kj::zeroTo(consumerCount)) { consumers.add(test); } - for (size_t i = 0; i < kItemCount; ++i) { + for (auto i : kj::zeroTo(kItemCount)) { // Use a producer and a consumer for each entry. auto& producer = producers[i % producerCount]; @@ -117,7 +117,7 @@ KJ_TEST("ProducerConsumerQueue with various amounts of producers and consumers") promises.add(kj::mv(consumer.promise)); } joinPromises(promises.releaseAsArray()).wait(test.io.waitScope); - for (auto i = 0; i < kItemCount; ++i) { + for (auto i : kj::zeroTo(kItemCount)) { KJ_ASSERT(bits[i], i); } } @@ -132,7 +132,7 @@ KJ_TEST("ProducerConsumerQueue with rejectAll()") { // Make enough consumers. auto promises = Vector>(); - for (size_t i = 0; i < consumerCount; ++i) { + for (auto i KJ_UNUSED : kj::zeroTo(consumerCount)) { promises.add(test.queue.pop().ignoreResult()); } @@ -148,4 +148,4 @@ KJ_TEST("ProducerConsumerQueue with rejectAll()") { } } // namespace -} // namespace kj \ No newline at end of file +} // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-queue.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-queue.h index 8e6f84c11ae..7a815faa35e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-queue.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-queue.h @@ -22,10 +22,10 @@ #pragma once #include "async.h" -#include "common.h" -#include "debug.h" -#include "list.h" -#include "memory.h" +#include +#include +#include +#include #include @@ -39,7 +39,7 @@ class WaiterQueue { // A WaiterQueue creates Nodes that blend newAdaptedPromise and List. WaiterQueue() = default; - KJ_DISALLOW_COPY(WaiterQueue); + KJ_DISALLOW_COPY_AND_MOVE(WaiterQueue); Promise wait() { return newAdaptedPromise(queue); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-test.c++ index fe11665de77..3ac6024cae2 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-test.c++ @@ -25,20 +25,55 @@ #include "mutex.h" #include "thread.h" -#if !KJ_USE_FIBERS +#if !KJ_USE_FIBERS && !_WIN32 #include #endif +#if KJ_USE_FIBERS && __linux__ +#include +#include +#endif + namespace kj { namespace { -#if !_MSC_VER || defined(__clang__) +#if !_MSC_VER // TODO(msvc): GetFunctorStartAddress is not supported on MSVC currently, so skip the test. TEST(Async, GetFunctorStartAddress) { EXPECT_TRUE(nullptr != _::GetFunctorStartAddress<>::apply([](){return 0;})); } #endif +#if KJ_USE_FIBERS +bool isLibcContextHandlingKnownBroken() { + // manylinux2014-x86's libc implements getcontext() to fail with ENOSYS. This is flagrantly + // against spec: getcontext() is not a syscall and is documented as never failing. Our configure + // script cannot detect this problem because it would require actually executing code to see + // what happens, which wouldn't work when cross-compiling. It would have been so much better if + // they had removed the symbol from libc entirely. But as a work-around, we will skip the tests + // when libc is broken. +#if __linux__ + static bool result = ([]() { + ucontext_t context; + if (getcontext(&context) < 0 && errno == ENOSYS) { + KJ_LOG(WARNING, + "This platform's libc is broken. Its getcontext() errors with ENOSYS. Fibers will not " + "work, so we'll skip the tests, but libkj was still built with fiber support, which " + "is broken. Please tell your libc maitnainer to remove the getcontext() function " + "entirely rather than provide an intentionally-broken version -- that way, the " + "configure script will detect that it should build libkj without fiber support."); + return true; + } else { + return false; + } + })(); + return result; +#else + return false; +#endif +} +#endif + TEST(Async, EvalVoid) { EventLoop loop; WaitScope waitScope(loop); @@ -192,9 +227,9 @@ TEST(Async, DeepChain) { // Create a ridiculous chain of promises. for (uint i = 0; i < 1000; i++) { - promise = evalLater(mvCapture(promise, [](Promise promise) { + promise = evalLater([promise=kj::mv(promise)]() mutable { return kj::mv(promise); - })); + }); } loop.run(); @@ -229,9 +264,9 @@ TEST(Async, DeepChain2) { // Create a ridiculous chain of promises. for (uint i = 0; i < 1000; i++) { - promise = evalLater(mvCapture(promise, [](Promise promise) { + promise = evalLater([promise=kj::mv(promise)]() mutable { return kj::mv(promise); - })); + }); } promise.wait(waitScope); @@ -268,9 +303,9 @@ TEST(Async, DeepChain3) { Promise makeChain2(uint i, Promise promise) { if (i > 0) { - return evalLater(mvCapture(promise, [i](Promise&& promise) -> Promise { + return evalLater([i, promise=kj::mv(promise)]() mutable -> Promise { return makeChain2(i - 1, kj::mv(promise)); - })); + }); } else { return kj::mv(promise); } @@ -632,36 +667,134 @@ TEST(Async, ExclusiveJoin) { } TEST(Async, ArrayJoin) { - EventLoop loop; - WaitScope waitScope(loop); + for (auto specificJoinPromisesOverload: { + +[](kj::Array> promises) { return joinPromises(kj::mv(promises)); }, + +[](kj::Array> promises) { return joinPromisesFailFast(kj::mv(promises)); } + }) { + EventLoop loop; + WaitScope waitScope(loop); - auto builder = heapArrayBuilder>(3); - builder.add(123); - builder.add(456); - builder.add(789); + auto builder = heapArrayBuilder>(3); + builder.add(123); + builder.add(456); + builder.add(789); - Promise> promise = joinPromises(builder.finish()); + Promise> promise = specificJoinPromisesOverload(builder.finish()); - auto result = promise.wait(waitScope); + auto result = promise.wait(waitScope); - ASSERT_EQ(3u, result.size()); - EXPECT_EQ(123, result[0]); - EXPECT_EQ(456, result[1]); - EXPECT_EQ(789, result[2]); + ASSERT_EQ(3u, result.size()); + EXPECT_EQ(123, result[0]); + EXPECT_EQ(456, result[1]); + EXPECT_EQ(789, result[2]); + } } TEST(Async, ArrayJoinVoid) { + for (auto specificJoinPromisesOverload: { + +[](kj::Array> promises) { return joinPromises(kj::mv(promises)); }, + +[](kj::Array> promises) { return joinPromisesFailFast(kj::mv(promises)); } + }) { + EventLoop loop; + WaitScope waitScope(loop); + + auto builder = heapArrayBuilder>(3); + builder.add(READY_NOW); + builder.add(READY_NOW); + builder.add(READY_NOW); + + Promise promise = specificJoinPromisesOverload(builder.finish()); + + promise.wait(waitScope); + } +} + +struct Pafs { + kj::Array> promises; + kj::Array>> fulfillers; +}; + +Pafs makeCompletionCountingPafs(uint count, uint& tasksCompleted) { + auto promisesBuilder = heapArrayBuilder>(count); + auto fulfillersBuilder = heapArrayBuilder>>(count); + + for (auto KJ_UNUSED value: zeroTo(count)) { + auto paf = newPromiseAndFulfiller(); + promisesBuilder.add(paf.promise.then([&tasksCompleted]() { + ++tasksCompleted; + })); + fulfillersBuilder.add(kj::mv(paf.fulfiller)); + } + + return { promisesBuilder.finish(), fulfillersBuilder.finish() }; +} + +TEST(Async, ArrayJoinException) { EventLoop loop; WaitScope waitScope(loop); - auto builder = heapArrayBuilder>(3); - builder.add(READY_NOW); - builder.add(READY_NOW); - builder.add(READY_NOW); + uint tasksCompleted = 0; + auto pafs = makeCompletionCountingPafs(5, tasksCompleted); + auto& fulfillers = pafs.fulfillers; + Promise promise = joinPromises(kj::mv(pafs.promises)); - Promise promise = joinPromises(builder.finish()); + { + uint i = 0; + KJ_EXPECT(tasksCompleted == 0); + + // Joined tasks are not completed early. + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + // Rejected tasks do not fail-fast. + fulfillers[i++]->reject(KJ_EXCEPTION(FAILED, "Test exception")); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + // The final fulfillment makes the promise ready. + fulfillers[i++]->fulfill(); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("Test exception", promise.wait(waitScope)); + KJ_EXPECT(tasksCompleted == 4); + } +} - promise.wait(waitScope); +TEST(Async, ArrayJoinFailFastException) { + EventLoop loop; + WaitScope waitScope(loop); + + uint tasksCompleted = 0; + auto pafs = makeCompletionCountingPafs(5, tasksCompleted); + auto& fulfillers = pafs.fulfillers; + Promise promise = joinPromisesFailFast(kj::mv(pafs.promises)); + + { + uint i = 0; + KJ_EXPECT(tasksCompleted == 0); + + // Joined tasks are completed eagerly, not waiting until the join node is awaited. + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == i); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == i); + + fulfillers[i++]->reject(KJ_EXCEPTION(FAILED, "Test exception")); + + // The first rejection makes the promise ready. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("Test exception", promise.wait(waitScope)); + KJ_EXPECT(tasksCompleted == i - 1); + } } TEST(Async, Canceler) { @@ -737,6 +870,10 @@ TEST(Async, TaskSet) { EXPECT_EQ(1u, errorHandler.exceptionCount); } +#if KJ_USE_FIBERS || !_WIN32 +// This test requires either fibers or pthreads in order to limit the stack size. Currently we +// don't have a version that works on Windows without fibers, so skip the test there. + TEST(Async, LargeTaskSetDestruction) { static constexpr size_t stackSize = 200 * 1024; @@ -751,6 +888,8 @@ TEST(Async, LargeTaskSetDestruction) { }; #if KJ_USE_FIBERS + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); @@ -776,6 +915,8 @@ TEST(Async, LargeTaskSetDestruction) { #endif } +#endif // KJ_USE_FIBERS || !_WIN32 + TEST(Async, TaskSet) { EventLoop loop; WaitScope waitScope(loop); @@ -830,6 +971,52 @@ TEST(Async, TaskSetOnEmpty) { promise.wait(waitScope); } +KJ_TEST("TaskSet::clear()") { + EventLoop loop; + WaitScope waitScope(loop); + + class ClearOnError: public TaskSet::ErrorHandler { + public: + TaskSet* tasks; + void taskFailed(kj::Exception&& exception) override { + KJ_EXPECT(exception.getDescription().endsWith("example TaskSet failure")); + tasks->clear(); + } + }; + + ClearOnError errorHandler; + TaskSet tasks(errorHandler); + errorHandler.tasks = &tasks; + + auto doTest = [&](auto&& causeClear) { + KJ_EXPECT(tasks.isEmpty()); + + uint count = 0; + tasks.add(kj::Promise(kj::READY_NOW).attach(kj::defer([&]() { ++count; }))); + tasks.add(kj::Promise(kj::NEVER_DONE).attach(kj::defer([&]() { ++count; }))); + tasks.add(kj::Promise(kj::NEVER_DONE).attach(kj::defer([&]() { ++count; }))); + + auto onEmpty = tasks.onEmpty(); + KJ_EXPECT(!onEmpty.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(!tasks.isEmpty()); + + causeClear(); + KJ_EXPECT(tasks.isEmpty()); + onEmpty.wait(waitScope); + KJ_EXPECT(count == 3); + }; + + // Try it where we just call clear() directly. + doTest([&]() { tasks.clear(); }); + + // Try causing clear() inside taskFailed(), ensuring that this is permitted. + doTest([&]() { + tasks.add(KJ_EXCEPTION(FAILED, "example TaskSet failure")); + waitScope.poll(); + }); +} + class DestructorDetector { public: DestructorDetector(bool& setTrue): setTrue(setTrue) {} @@ -972,6 +1159,46 @@ TEST(Async, Poll) { paf.promise.wait(waitScope); } +KJ_TEST("Maximum turn count during wait scope poll is enforced") { + EventLoop loop; + WaitScope waitScope(loop); + ErrorHandlerImpl errorHandler; + TaskSet tasks(errorHandler); + + auto evaluated1 = false; + tasks.add(evalLater([&]() { + evaluated1 = true; + })); + + auto evaluated2 = false; + tasks.add(evalLater([&]() { + evaluated2 = true; + })); + + auto evaluated3 = false; + tasks.add(evalLater([&]() { + evaluated3 = true; + })); + + uint count; + + // Check that only events up to a maximum are resolved: + count = waitScope.poll(2); + KJ_ASSERT(count == 2); + KJ_EXPECT(evaluated1); + KJ_EXPECT(evaluated2); + KJ_EXPECT(!evaluated3); + + // Get the last remaining event in the queue: + count = waitScope.poll(1); + KJ_ASSERT(count == 1); + KJ_EXPECT(evaluated3); + + // No more events: + count = waitScope.poll(1); + KJ_ASSERT(count == 0); +} + KJ_TEST("exclusiveJoin both events complete simultaneously") { // Previously, if both branches of an exclusiveJoin() completed simultaneously, then the parent // event could be armed twice. This is an error, but the exact results of this error depend on @@ -991,6 +1218,8 @@ KJ_TEST("exclusiveJoin both events complete simultaneously") { #if KJ_USE_FIBERS KJ_TEST("start a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); @@ -1012,6 +1241,8 @@ KJ_TEST("start a fiber") { } KJ_TEST("fiber promise chaining") { + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); @@ -1035,6 +1266,8 @@ KJ_TEST("fiber promise chaining") { } KJ_TEST("throw from a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); @@ -1058,6 +1291,8 @@ KJ_TEST("throw from a fiber") { // This test fails on MinGW 32-bit builds due to a compiler bug with exceptions + fibers: // https://sourceforge.net/p/mingw-w64/bugs/835/ KJ_TEST("cancel a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); @@ -1091,6 +1326,8 @@ KJ_TEST("cancel a fiber") { #endif KJ_TEST("fiber pool") { + if (isLibcContextHandlingKnownBroken()) return; + EventLoop loop; WaitScope waitScope(loop); FiberPool pool(65536); @@ -1109,7 +1346,12 @@ KJ_TEST("fiber pool") { if (i1_local == nullptr) { i1_local = &i; } else { +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // Verify that the stack variable is in the exact same spot as before. + // May not work under ASAN as the instrumentation to detect stack-use-after-return can + // change the address. KJ_ASSERT(i1_local == &i); +#endif } return i; }); @@ -1120,7 +1362,9 @@ KJ_TEST("fiber pool") { if (i2_local == nullptr) { i2_local = &i; } else { +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) KJ_ASSERT(i2_local == &i); +#endif } return i; }); @@ -1148,8 +1392,8 @@ KJ_TEST("fiber pool") { } }; run(); - KJ_ASSERT_NONNULL(i1_local); - KJ_ASSERT_NONNULL(i2_local); + KJ_ASSERT(i1_local != nullptr); + KJ_ASSERT(i2_local != nullptr); // run the same thing and reuse the fibers run(); } @@ -1157,12 +1401,29 @@ KJ_TEST("fiber pool") { bool onOurStack(char* p) { // If p points less than 64k away from a random stack variable, then it must be on the same // stack, since we never allocate stacks smaller than 64k. +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // The stack-use-after-return detection mechanism breaks our ability to check this, so don't. + return true; +#else char c; ptrdiff_t diff = p - &c; return diff < 65536 && diff > -65536; +#endif +} + +bool notOnOurStack(char* p) { + // Opposite of onOurStack(), except returns true if the check can't be performed. +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // The stack-use-after-return detection mechanism breaks our ability to check this, so don't. + return true; +#else + return !onOurStack(p); +#endif } KJ_TEST("fiber pool runSynchronously()") { + if (isLibcContextHandlingKnownBroken()) return; + FiberPool pool(65536); { @@ -1185,17 +1446,22 @@ KJ_TEST("fiber pool runSynchronously()") { }); KJ_ASSERT(ptr2 != nullptr); +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) // Should have used the same stack both times, so local var would be in the same place. + // Under ASAN, the stack-use-after-return detection correctly fires on this, so we skip the check. KJ_EXPECT(ptr1 == ptr2); +#endif // Should have been on a different stack from the main stack. - KJ_EXPECT(!onOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(ptr1)); KJ_EXPECT_THROW_MESSAGE("test exception", pool.runSynchronously([&]() { KJ_FAIL_ASSERT("test exception"); })); } KJ_TEST("fiber pool limit") { + if (isLibcContextHandlingKnownBroken()) return; + FiberPool pool(65536); pool.setMaxFreelist(1); @@ -1241,7 +1507,7 @@ KJ_TEST("fiber pool limit") { // is the one from the thread. pool.runSynchronously([&]() { KJ_EXPECT(onOurStack(ptr2)); - KJ_EXPECT(!onOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(ptr1)); KJ_EXPECT(pool.getFreelistSize() == 0); }); @@ -1253,7 +1519,15 @@ KJ_TEST("fiber pool limit") { // likelihood that the new stack would be allocated in the same location. } +#if __GNUC__ >= 12 && !__clang__ +// The test below intentionally takes a pointer to a stack variable and stores it past the end +// of the function. This seems to trigger a warning in newer GCCs. +#pragma GCC diagnostic ignored "-Wdangling-pointer" +#endif + KJ_TEST("run event loop on freelisted stacks") { + if (isLibcContextHandlingKnownBroken()) return; + FiberPool pool(65536); class MockEventPort: public EventPort { @@ -1310,15 +1584,15 @@ KJ_TEST("run event loop on freelisted stacks") { // The event callbacks should have run on a different stack, but the wait should have been on // the main stack. - KJ_EXPECT(!onOurStack(ptr1)); - KJ_EXPECT(!onOurStack(ptr2)); + KJ_EXPECT(notOnOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(ptr2)); KJ_EXPECT(onOurStack(port.waitStack)); pool.runSynchronously([&]() { // This should run on the same stack where the event callbacks ran. KJ_EXPECT(onOurStack(ptr1)); KJ_EXPECT(onOurStack(ptr2)); - KJ_EXPECT(!onOurStack(port.waitStack)); + KJ_EXPECT(notOnOurStack(port.waitStack)); }); } @@ -1351,8 +1625,8 @@ KJ_TEST("run event loop on freelisted stacks") { // The event callback should have run on a different stack, and poll() should have run on // a separate stack too. - KJ_EXPECT(!onOurStack(ptr1)); - KJ_EXPECT(!onOurStack(port.pollStack)); + KJ_EXPECT(notOnOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(port.pollStack)); pool.runSynchronously([&]() { // This should run on the same stack where the event callbacks ran. @@ -1429,5 +1703,45 @@ KJ_TEST("retryOnDisconnect") { } } +#if (__GLIBC__ == 2 && __GLIBC_MINOR__ <= 17) || (__MINGW32__ && !__MINGW64__) +// manylinux2014-x86 doesn't seem to respect `alignas(16)`. I am guessing this is a glibc issue +// but I don't really know. It uses glibc 2.17, so testing for that and skipping the test makes +// CI work. +// +// MinGW 32-bit also mysteriously fails this test but I am not going to spend time figuring out +// why. +#else +KJ_TEST("capture weird alignment in continuation") { + struct alignas(16) WeirdAlign { + ~WeirdAlign() { + KJ_EXPECT(reinterpret_cast(this) % 16 == 0); + } + int i; + }; + + EventLoop loop; + WaitScope waitScope(loop); + + kj::Promise p = kj::READY_NOW; + + WeirdAlign value = { 123 }; + WeirdAlign value2 = { 456 }; + auto p2 = p.then([value, value2]() -> WeirdAlign { + return { value.i + value2.i }; + }); + + KJ_EXPECT(p2.wait(waitScope).i == 579); +} +#endif + +KJ_TEST("constPromise") { + EventLoop loop; + WaitScope waitScope(loop); + + Promise p = constPromise(); + int i = p.wait(waitScope); + KJ_EXPECT(i == 123); +} + } // namespace } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix-test.c++ index c8012e4ecb8..64190c430fa 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix-test.c++ @@ -37,8 +37,17 @@ #include #include #include +#include #include "mutex.h" +#if KJ_USE_EPOLL +#include +#endif + +#if KJ_USE_KQUEUE +#include +#endif + #if __BIONIC__ // Android's Bionic defines SIGRTMIN but using it in sigaddset() throws EINVAL, which means we // definitely can't actually use RT signals. @@ -76,7 +85,67 @@ void captureSignals() { } } +#if KJ_USE_EPOLL +bool qemuBugTestSignalHandlerRan = false; +void qemuBugTestSignalHandler(int, siginfo_t* siginfo, void*) { + qemuBugTestSignalHandlerRan = true; +} + +bool checkForQemuEpollPwaitBug() { + // Under qemu-user, when a signal is delivered during epoll_pwait(), the signal successfully + // interrupts the wait, but the correct signal handler is not run. This ruins all our tests so + // we check for it and skip tests in this case. This does imply UnixEventPort won't be able to + // handle signals correctly under qemu-user. + + sigset_t mask; + sigset_t origMask; + KJ_SYSCALL(sigemptyset(&mask)); + KJ_SYSCALL(sigaddset(&mask, SIGURG)); + KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &mask, &origMask)); + KJ_DEFER(KJ_SYSCALL(pthread_sigmask(SIG_SETMASK, &origMask, nullptr))); + + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_sigaction = &qemuBugTestSignalHandler; + action.sa_flags = SA_SIGINFO; + + KJ_SYSCALL(sigfillset(&action.sa_mask)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGBUS)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGFPE)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGILL)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGSEGV)); + + KJ_SYSCALL(sigaction(SIGURG, &action, nullptr)); + + int efd; + KJ_SYSCALL(efd = epoll_create1(EPOLL_CLOEXEC)); + KJ_DEFER(close(efd)); + + kill(getpid(), SIGURG); + KJ_ASSERT(!qemuBugTestSignalHandlerRan); + + struct epoll_event event; + int n = epoll_pwait(efd, &event, 1, -1, &origMask); + KJ_ASSERT(n < 0); + KJ_ASSERT(errno == EINTR); + +#if !__aarch64__ + // qemu-user should only be used to execute aarch64 binaries so we should'nt see this bug + // elsewhere! + KJ_ASSERT(qemuBugTestSignalHandlerRan); +#endif + + return !qemuBugTestSignalHandlerRan; +} + +const bool BROKEN_QEMU = checkForQemuEpollPwaitBug(); +#else +const bool BROKEN_QEMU = false; +#endif + TEST(AsyncUnixTest, Signals) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -100,7 +169,9 @@ TEST(AsyncUnixTest, SignalWithValue) { // // Also, this test fails on Linux on mipsel. si_value comes back as zero. No one with a mips // machine wants to debug the problem but they demand a patch fixing it, so we disable the test. - // Sad. https://github.com/sandstorm-io/capnproto/issues/204 + // Sad. https://github.com/capnproto/capnproto/issues/204 + + if (BROKEN_QEMU) return; captureSignals(); UnixEventPort port; @@ -135,7 +206,9 @@ TEST(AsyncUnixTest, SignalWithPointerValue) { // // Also, this test fails on Linux on mipsel. si_value comes back as zero. No one with a mips // machine wants to debug the problem but they demand a patch fixing it, so we disable the test. - // Sad. https://github.com/sandstorm-io/capnproto/issues/204 + // Sad. https://github.com/capnproto/capnproto/issues/204 + + if (BROKEN_QEMU) return; captureSignals(); UnixEventPort port; @@ -162,6 +235,8 @@ TEST(AsyncUnixTest, SignalWithPointerValue) { #endif TEST(AsyncUnixTest, SignalsMultiListen) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -186,6 +261,8 @@ TEST(AsyncUnixTest, SignalsMultiListen) { // platform I'm assuming it's a Cygwin bug. TEST(AsyncUnixTest, SignalsMultiReceive) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -206,16 +283,24 @@ TEST(AsyncUnixTest, SignalsMultiReceive) { #endif // !__CYGWIN32__ TEST(AsyncUnixTest, SignalsAsync) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); WaitScope waitScope(loop); // Arrange for a signal to be sent from another thread. - pthread_t mainThread = pthread_self(); + pthread_t mainThread KJ_UNUSED = pthread_self(); Thread thread([&]() { delay(); +#if __APPLE__ && KJ_USE_KQUEUE + // MacOS kqueue only receives process-level signals and there's nothing much we can do about + // that. + kill(getpid(), SIGURG); +#else pthread_kill(mainThread, SIGURG); +#endif }); siginfo_t info = port.onSignal(SIGURG).wait(waitScope); @@ -366,6 +451,32 @@ TEST(AsyncUnixTest, ReadObserverMultiReceive) { promise2.wait(waitScope); } +TEST(AsyncUnixTest, ReadObserverAndSignals) { + // Get FD events while also waiting on a signal. This specifically exercises epoll_pwait() for + // FD events on Linux. + + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + auto signalPromise = port.onSignal(SIGIO); + + int pipefds[2]; + KJ_SYSCALL(pipe(pipefds)); + kj::AutoCloseFd infd(pipefds[0]), outfd(pipefds[1]); + + UnixEventPort::FdObserver observer(port, infd, UnixEventPort::FdObserver::OBSERVE_READ); + + KJ_SYSCALL(write(outfd, "foo", 3)); + + observer.whenBecomesReadable().wait(waitScope); + + KJ_EXPECT(!signalPromise.poll(waitScope)) + kill(getpid(), SIGIO); + KJ_EXPECT(signalPromise.poll(waitScope)) +} + TEST(AsyncUnixTest, ReadObserverAsync) { captureSignals(); UnixEventPort port; @@ -488,8 +599,9 @@ TEST(AsyncUnixTest, WriteObserver) { EXPECT_TRUE(writable); } -#if !__APPLE__ -// Disabled on macOS due to https://github.com/sandstorm-io/capnproto/issues/374. +#if !__APPLE__ && !(KJ_USE_KQUEUE && !defined(EVFILT_EXCEPT)) +// Disabled on macOS due to https://github.com/capnproto/capnproto/issues/374. +// Disabled on kqueue systems that lack EVFILT_EXCEPT because it doesn't work there. TEST(AsyncUnixTest, UrgentObserver) { // Verify that FdObserver correctly detects availability of out-of-band data. // Availability of out-of-band data is implementation-specific. @@ -773,14 +885,13 @@ struct TestChild { KJ_SYSCALL(::kill(KJ_REQUIRE_NONNULL(pid), signo)); } - KJ_DISALLOW_COPY(TestChild); + KJ_DISALLOW_COPY_AND_MOVE(TestChild); }; TEST(AsyncUnixTest, ChildProcess) { + if (BROKEN_QEMU) return; + captureSignals(); - UnixEventPort port; - EventLoop loop(port); - WaitScope waitScope(loop); // Block SIGTERM so that we can carefully un-block it in children. sigset_t sigs, oldsigs; @@ -789,6 +900,10 @@ TEST(AsyncUnixTest, ChildProcess) { KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &sigs, &oldsigs)); KJ_DEFER(KJ_SYSCALL(pthread_sigmask(SIG_SETMASK, &oldsigs, nullptr)) { break; }); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + TestChild child1(port, 123); KJ_EXPECT(!child1.promise.poll(waitScope)); @@ -904,8 +1019,8 @@ KJ_TEST("UnixEventPort poll for signals") { KJ_EXPECT(!promise1.poll(waitScope)); KJ_EXPECT(!promise2.poll(waitScope)); - KJ_SYSCALL(raise(SIGURG)); - KJ_SYSCALL(raise(SIGIO)); + KJ_SYSCALL(kill(getpid(), SIGURG)); + KJ_SYSCALL(kill(getpid(), SIGIO)); port.wake(); KJ_EXPECT(port.poll()); @@ -962,6 +1077,48 @@ KJ_TEST("UnixEventPort can receive multiple queued instances of an RT signal") { } #endif +#if !(__APPLE__ && KJ_USE_KQUEUE) +KJ_TEST("UnixEventPort thread-specific signals") { + // Verify a signal directed to a thread is only received on the intended thread. + // + // MacOS kqueue only receives process-level signals and there's nothing much we can do about + // that, so this test won't work there. + + if (BROKEN_QEMU) return; + + captureSignals(); + + Vector> threads; + std::atomic readyCount(0); + std::atomic doneCount(0); + for (auto i KJ_UNUSED: kj::zeroTo(16)) { + threads.add(kj::heap([&]() noexcept { + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + readyCount.fetch_add(1, std::memory_order_relaxed); + port.onSignal(SIGIO).wait(waitScope); + doneCount.fetch_add(1, std::memory_order_relaxed); + })); + } + + do { + usleep(1000); + } while (readyCount.load(std::memory_order_relaxed) < 16); + + KJ_ASSERT(doneCount.load(std::memory_order_relaxed) == 0); + + uint count = 0; + for (uint i: {5, 14, 4, 6, 7, 11, 1, 3, 8, 0, 12, 9, 10, 15, 2, 13}) { + threads[i]->sendSignal(SIGIO); + threads[i] = nullptr; // wait for that one thread to exit + usleep(1000); + KJ_ASSERT(doneCount.load(std::memory_order_relaxed) == ++count); + } +} +#endif + } // namespace } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.c++ index 796f629182e..a8179ea5ae8 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.c++ @@ -35,8 +35,15 @@ #if KJ_USE_EPOLL #include -#include #include +#elif KJ_USE_KQUEUE +#include +#include +#include +#if !__APPLE__ && !__OpenBSD__ +// MacOS and OpenBSD are missing this, which means we have to do ugly hacks instead on those. +#define KJ_HAS_SIGTIMEDWAIT 1 +#endif #else #include #include @@ -52,7 +59,69 @@ namespace { int reservedSignal = SIGUSR1; bool tooLateToSetReserved = false; bool capturedChildExit = false; + +#if !KJ_USE_KQUEUE bool threadClaimedChildExits = false; +#endif + +} // namespace + +#if KJ_USE_EPOLL + +namespace { + +KJ_THREADLOCAL_PTR(UnixEventPort) threadEventPort = nullptr; +// This is set to the current UnixEventPort just before epoll_pwait(), then back to null after it +// returns. + +} // namespace + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { + // Since this signal handler is *only* called during `epoll_pwait()`, we aren't subject to the + // usual signal-safety concerns. We can treat this more like a callback. So, we can just call + // gotSignal() directly, no biggy. + + // Note that, if somehow the signal hanlder is invoked when *not* running `epoll_pwait()`, then + // `threadEventPort` will be null. We silently ignore the signal in this case. This should never + // happen in normal execution, so you might argue we should assert-fail instead. However: + // - We obviously can't throw from here, so we'd have to crash instead. + // - The Cloudflare Workers runtime relies on this no-op behavior for a certain hack. The hack + // in question involves unblocking a signal from the signal mask and relying on it to interrupt + // certain blocking syscalls, causing them to fail with EINTR. The hack does not need the + // handler to do anything except return in this case. The hacky code makes sure to restore the + // signal mask before returning to the event loop. + + UnixEventPort* current = threadEventPort; + if (current != nullptr) { + current->gotSignal(*siginfo); + } +} + +#elif KJ_USE_KQUEUE + +#if !KJ_HAS_SIGTIMEDWAIT +KJ_THREADLOCAL_PTR(siginfo_t) threadCapture = nullptr; +#endif + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { +#if KJ_HAS_SIGTIMEDWAIT + // This is never called because we use sigtimedwait() to dequeue the signal while it is still + // blocked, without running the signal handler. However, if we don't register a handler at all, + // and the default behavior is SIG_IGN, then the signal will be discarded before sigtimedwait() + // can receive it. +#else + // When sigtimedwait() isn't available, we use sigsuspend() and wait for the siginfo_t to be + // delivered to the signal handler. + siginfo_t* capture = threadCapture; + if (capture != nullptr) { + *capture = *siginfo; + } +#endif +} + +#else + +namespace { struct SignalCapture { sigjmp_buf jumpTo; @@ -83,10 +152,11 @@ struct SignalCapture { #endif }; -#if !KJ_USE_EPOLL // on Linux we'll use signalfd KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr; -void signalHandler(int, siginfo_t* siginfo, void*) { +} // namespace + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { SignalCapture* capture = threadCapture; if (capture != nullptr) { capture->siginfo = *siginfo; @@ -104,43 +174,64 @@ void signalHandler(int, siginfo_t* siginfo, void*) { #endif } } -#endif -void registerSignalHandler(int signum) { +#endif // !KJ_USE_EPOLL && !KJ_USE_KQUEUE + +void UnixEventPort::registerSignalHandler(int signum) { + KJ_REQUIRE(signum != SIGBUS && signum != SIGFPE && signum != SIGILL && signum != SIGSEGV, + "this signal is raised by erroneous code execution; you cannot capture it into the event " + "loop"); + tooLateToSetReserved = true; + // Block the signal from being delivered most of the time. We'll explicitly unblock it when we + // want to receive it. sigset_t mask; KJ_SYSCALL(sigemptyset(&mask)); KJ_SYSCALL(sigaddset(&mask, signum)); KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &mask, nullptr)); -#if !KJ_USE_EPOLL // on Linux we'll use signalfd + // Register the signal handler which should be invoked when we explicitly unblock the signal. struct sigaction action; memset(&action, 0, sizeof(action)); action.sa_sigaction = &signalHandler; - KJ_SYSCALL(sigfillset(&action.sa_mask)); action.sa_flags = SA_SIGINFO; + + // Set up the signal mask applied while the signal handler runs. We want to block all other + // signals from being raised during the handler, with the exception of the four "crash" signals, + // which realistically can't be blocked. + KJ_SYSCALL(sigfillset(&action.sa_mask)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGBUS)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGFPE)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGILL)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGSEGV)); + KJ_SYSCALL(sigaction(signum, &action, nullptr)); -#endif } -#if !KJ_USE_EPOLL && !KJ_USE_PIPE_FOR_WAKEUP -void registerReservedSignal() { +#if !KJ_USE_EPOLL && !KJ_USE_KQUEUE && !KJ_USE_PIPE_FOR_WAKEUP +void UnixEventPort::registerReservedSignal() { registerSignalHandler(reservedSignal); } #endif -void ignoreSigpipe() { +void UnixEventPort::ignoreSigpipe() { // We disable SIGPIPE because users of UnixEventPort almost certainly don't want it. - while (signal(SIGPIPE, SIG_IGN) == SIG_ERR) { - int error = errno; - if (error != EINTR) { - KJ_FAIL_SYSCALL("signal(SIGPIPE, SIG_IGN)", error); + // + // We've observed that when starting many threads at the same time, this can cause some + // contention on the kernel's signal handler table lock, so we try to run it only once. + static bool once KJ_UNUSED = []() { + while (signal(SIGPIPE, SIG_IGN) == SIG_ERR) { + int error = errno; + if (error != EINTR) { + KJ_FAIL_SYSCALL("signal(SIGPIPE, SIG_IGN)", error); + } } - } + return true; + }(); } -} // namespace +#if !KJ_USE_KQUEUE // kqueue systems handle child processes differently struct UnixEventPort::ChildSet { std::map waiters; @@ -264,6 +355,8 @@ Promise UnixEventPort::onSignal(int signum) { return newAdaptedPromise(*this, signum); } +#endif // !KJ_USE_KQUEUE + void UnixEventPort::captureSignal(int signum) { if (reservedSignal == SIGUSR1) { KJ_REQUIRE(signum != SIGUSR1, @@ -287,6 +380,8 @@ void UnixEventPort::setReservedSignal(int signum) { reservedSignal = signum; } +#if !KJ_USE_KQUEUE + void UnixEventPort::gotSignal(const siginfo_t& siginfo) { // If onChildExit() has been called and this is SIGCHLD, check for child exits. KJ_IF_MAYBE(cs, childSet) { @@ -308,39 +403,35 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) { } } +#endif // !KJ_USE_KQUEUE + #if KJ_USE_EPOLL // ======================================================================================= // epoll FdObserver implementation UnixEventPort::UnixEventPort() : clock(systemPreciseMonotonicClock()), - timerImpl(clock.now()), - epollFd(-1), - signalFd(-1), - eventFd(-1) { + timerImpl(clock.now()) { ignoreSigpipe(); int fd; KJ_SYSCALL(fd = epoll_create1(EPOLL_CLOEXEC)); epollFd = AutoCloseFd(fd); - memset(&signalFdSigset, 0, sizeof(signalFdSigset)); - - KJ_SYSCALL(sigemptyset(&signalFdSigset)); - KJ_SYSCALL(fd = signalfd(-1, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC)); - signalFd = AutoCloseFd(fd); - KJ_SYSCALL(fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); eventFd = AutoCloseFd(fd); - struct epoll_event event; memset(&event, 0, sizeof(event)); event.events = EPOLLIN; event.data.u64 = 0; - KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, signalFd, &event)); - event.data.u64 = 1; KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event)); + + // Get the current signal mask, from which we'll compute the appropriate mask to pass to + // epoll_pwait() on each loop. (We explicitly memset to 0 first to make sure we can compare + // this against another mask with memcmp() for debug purposes.) + memset(&originalMask, 0, sizeof(originalMask)); + KJ_SYSCALL(sigprocmask(0, nullptr, &originalMask)); } UnixEventPort::~UnixEventPort() noexcept(false) { @@ -443,208 +534,592 @@ Promise UnixEventPort::FdObserver::whenWriteDisconnected() { return kj::mv(paf.promise); } +void UnixEventPort::wake() const { + uint64_t one = 1; + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = write(eventFd, &one, sizeof(one))); + KJ_ASSERT(n < 0 || n == sizeof(one)); +} + bool UnixEventPort::wait() { - return doEpollWait( - timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, int(maxValue)) +#ifdef KJ_DEBUG + // In debug mode, verify the current signal mask matches the original. + { + sigset_t currentMask; + memset(¤tMask, 0, sizeof(currentMask)); + KJ_SYSCALL(sigprocmask(0, nullptr, ¤tMask)); + if (memcmp(¤tMask, &originalMask, sizeof(currentMask)) != 0) { + kj::Vector changes; + for (int i = 0; i <= SIGRTMAX; i++) { + if (sigismember(¤tMask, i) && !sigismember(&originalMask, i)) { + changes.add(kj::str("signal #", i, " (", strsignal(i), ") was added")); + } else if (!sigismember(¤tMask, i) && sigismember(&originalMask, i)) { + changes.add(kj::str("signal #", i, " (", strsignal(i), ") was removed")); + } + } + + KJ_FAIL_REQUIRE( + "Signal mask has changed since UnixEventPort was constructed. You are required to " + "ensure that whenever control returns to the event loop, the signal mask is the same " + "as it was when UnixEventPort was created. In non-debug builds, this check is skipped, " + "and this situation may instead lead to unexpected results. In particular, while the " + "system is waiting for I/O events, the signal mask may be reverted to what it was at " + "construction time, ignoring your subsequent changes.", changes); + } + } +#endif + + int timeout = timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, int(maxValue)) .map([](uint64_t t) -> int { return t; }) - .orDefault(-1)); + .orDefault(-1); + + struct epoll_event events[16]; + int n; + if (signalHead != nullptr || childSet != nullptr) { + // We are interested in some signals. Use epoll_pwait(). + // + // Note: Once upon a time, we used signalfd for this. However, this turned out to be more + // trouble than it was worth. Some problems with signalfd: + // - It required opening an additional file descriptor per thread. + // - If the set of interesting signals changed, the signalfd would have to be updated before + // calling epoll_wait(), which was an extra syscall. + // - When a signal arrives, it requires extra syscalls to read the signal info from the + // signalfd, as well as code to translate from signalfd_siginfo to siginfo_t, which are + // different for some reason. + // - signalfd suffers from surprising lock contention during epoll_wait or when the signalfd's + // mask is updated in programs with many threads. Because the lock is a spinlock, this + // could consume exorbitant CPU. + // - When a signalfd is in an epoll, it will be flagged readable based on signals which are + // pending in the process/thread which called epoll_ctl_add() to register the signalfd. + // This is mostly fine for our usage, except that it breaks one useful case that otherwise + // works: many servers are designed to "daemonize" themselves by fork()ing and then having + // the parent process exit while the child thread lives on. In this case, if a UnixEventPort + // had been created before daemonizing, signal handling would be forever broken in the child. + + sigset_t waitMask = originalMask; + + // Unblock the signals we care about. + { + auto ptr = signalHead; + while (ptr != nullptr) { + KJ_SYSCALL(sigdelset(&waitMask, ptr->signum)); + ptr = ptr->next; + } + if (childSet != nullptr) { + KJ_SYSCALL(sigdelset(&waitMask, SIGCHLD)); + } + } + + threadEventPort = this; + n = epoll_pwait(epollFd, events, kj::size(events), timeout, &waitMask); + threadEventPort = nullptr; + } else { + // Not waiting on any signals. Regular epoll_wait() will be fine. + n = epoll_wait(epollFd, events, kj::size(events), timeout); + } + + if (n < 0) { + int error = errno; + if (error == EINTR) { + // We received a singal. The signal handler may have queued an event to the event loop. Even + // if it didn't, we can't simply restart the epoll call because we need to recompute the + // timeout. Instead, we pretend epoll_wait() returned zero events. This will cause the event + // loop to spin once, decide it has nothing to do, recompute timeouts, then return to waiting. + n = 0; + } else { + KJ_FAIL_SYSCALL("epoll_pwait()", error); + } + } + + return processEpollEvents(events, n); +} + +bool UnixEventPort::processEpollEvents(struct epoll_event events[], int n) { + bool woken = false; + + for (int i = 0; i < n; i++) { + if (events[i].data.u64 == 0) { + // Someone called wake() from another thread. Consume the event. + uint64_t value; + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value))); + KJ_ASSERT(n < 0 || n == sizeof(value)); + + // We were woken. Need to return true. + woken = true; + } else { + FdObserver* observer = reinterpret_cast(events[i].data.ptr); + observer->fire(events[i].events); + } + } + + timerImpl.advanceTo(clock.now()); + + return woken; } bool UnixEventPort::poll() { - return doEpollWait(0); + // Unfortunately, epoll_pwait() with a timeout of zero will never deliver actually deliver any + // pending signals. Therefore, we need a completely different approach to poll for signals. We + // might as well use regular epoll_wait() in this case, too, to save the kernel some effort. + + if (signalHead != nullptr || childSet != nullptr) { + // Use sigtimedwait() to poll for signals. + + // Construct a sigset of all signals we are interested in. + sigset_t sigset; + KJ_SYSCALL(sigemptyset(&sigset)); + uint count = 0; + + { + auto ptr = signalHead; + while (ptr != nullptr) { + KJ_SYSCALL(sigaddset(&sigset, ptr->signum)); + ++count; + ptr = ptr->next; + } + if (childSet != nullptr) { + KJ_SYSCALL(sigaddset(&sigset, SIGCHLD)); + ++count; + } + } + + // While that set is non-empty, poll for signals. + while (count > 0) { + struct timespec timeout; + timeout.tv_sec = 0; + timeout.tv_nsec = 0; + + siginfo_t siginfo; + int n; + KJ_NONBLOCKING_SYSCALL(n = sigtimedwait(&sigset, &siginfo, &timeout)); + if (n < 0) break; // EAGAIN: no signals in set are raised + + KJ_ASSERT(n == siginfo.si_signo); + gotSignal(siginfo); + + // Remove that signal from the set so we don't receive it again, but keep checking for others + // if there are any. + KJ_SYSCALL(sigdelset(&sigset, n)); + --count; + } + } + + struct epoll_event events[16]; + int n; + KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), 0)); + + return processEpollEvents(events, n); } -void UnixEventPort::wake() const { - uint64_t one = 1; - ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = write(eventFd, &one, sizeof(one))); - KJ_ASSERT(n < 0 || n == sizeof(one)); +#elif KJ_USE_KQUEUE +// ======================================================================================= +// kqueue FdObserver implementation + +UnixEventPort::UnixEventPort() + : clock(systemPreciseMonotonicClock()), + timerImpl(clock.now()) { + ignoreSigpipe(); + + int fd; + KJ_SYSCALL(fd = kqueue()); + kqueueFd = AutoCloseFd(fd); + + // NetBSD has kqueue1() which can set CLOEXEC atomically, but FreeBSD, MacOS, and others don't + // have this... oh well. + KJ_SYSCALL(fcntl(kqueueFd, F_SETFD, FD_CLOEXEC)); + + // Register the EVFILT_USER event used by wake(). + struct kevent event; + EV_SET(&event, 0, EVFILT_USER, EV_ADD | EV_CLEAR, 0, 0, nullptr); + KJ_SYSCALL(kevent(kqueueFd, &event, 1, nullptr, 0, nullptr)); } -static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) { - // Unfortunately, siginfo_t is mostly a big union and the correct set of fields to fill in - // depends on the type of signal. OTOH, signalfd_siginfo is a flat struct that expands all - // siginfo_t's union fields out to be non-overlapping. We can't just copy all the fields over - // because of the unions; we have to carefully figure out which fields are appropriate to fill - // in for this signal. Ick. - - siginfo_t result; - memset(&result, 0, sizeof(result)); - - result.si_signo = siginfo.ssi_signo; - result.si_errno = siginfo.ssi_errno; - result.si_code = siginfo.ssi_code; - - if (siginfo.ssi_code > 0) { - // Signal originated from the kernel. The structure of the siginfo depends primarily on the - // signal number. - - switch (siginfo.ssi_signo) { - case SIGCHLD: - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - result.si_status = siginfo.ssi_status; - result.si_utime = siginfo.ssi_utime; - result.si_stime = siginfo.ssi_stime; - break; +UnixEventPort::~UnixEventPort() noexcept(false) {} + +UnixEventPort::FdObserver::FdObserver(UnixEventPort& eventPort, int fd, uint flags) + : eventPort(eventPort), fd(fd), flags(flags) { + struct kevent events[3]; + int nevents = 0; - case SIGILL: - case SIGFPE: - case SIGSEGV: - case SIGBUS: - case SIGTRAP: - result.si_addr = reinterpret_cast(static_cast(siginfo.ssi_addr)); -#ifdef si_trapno - result.si_trapno = siginfo.ssi_trapno; + if (flags & OBSERVE_URGENT) { +#ifdef EVFILT_EXCEPT + EV_SET(&events[nevents++], fd, EVFILT_EXCEPT, EV_ADD | EV_CLEAR, NOTE_OOB, 0, this); +#else + // TODO(someday): Can we support this without reverting to poll()? + // Related: https://sandstorm.io/news/2015-04-08-osx-security-bug + KJ_FAIL_ASSERT("kqueue() on this system doesn't support EVFILT_EXCEPT (for OBSERVE_URGENT). " + "If you really need to observe OOB events, compile KJ (and your application) with " + "-DKJ_USE_KQUEUE=0 to disable use of kqueue()."); #endif -#ifdef si_addr_lsb - // ssi_addr_lsb is defined as coming immediately after ssi_addr in the kernel headers but - // apparently the userspace headers were never updated. So we do a pointer hack. :( - result.si_addr_lsb = *reinterpret_cast(&siginfo.ssi_addr + 1); + } + if (flags & OBSERVE_READ) { + EV_SET(&events[nevents++], fd, EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, this); + } + if (flags & OBSERVE_WRITE) { + EV_SET(&events[nevents++], fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, this); + } + + KJ_SYSCALL(kevent(eventPort.kqueueFd, events, nevents, nullptr, 0, nullptr)); +} + +UnixEventPort::FdObserver::~FdObserver() noexcept(false) { + struct kevent events[3]; + int nevents = 0; + + if (flags & OBSERVE_URGENT) { +#ifdef EVFILT_EXCEPT + EV_SET(&events[nevents++], fd, EVFILT_EXCEPT, EV_DELETE, 0, 0, nullptr); #endif + } + if (flags & OBSERVE_READ) { + EV_SET(&events[nevents++], fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + } + if ((flags & OBSERVE_WRITE) || hupFulfiller != nullptr) { + EV_SET(&events[nevents++], fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + } + + // TODO(perf): Should we delay unregistration of events until the next time kqueue() is invoked? + // We can't delay registrations since it could lead to missed events, but we could delay + // unregistration safely. However, we'd have to be very careful about the possibility that + // the same FD is re-registered later. + KJ_SYSCALL_HANDLE_ERRORS(kevent(eventPort.kqueueFd, events, nevents, nullptr, 0, nullptr)) { + case ENOENT: + // In the specific case of unnamed pipes, when read end of the pipe is destroyed, FreeBSD + // seems to unregister the events on the write end automatically. Subsequently trying to + // remove them then produces ENOENT. Let's ignore this. break; + default: + KJ_FAIL_SYSCALL("kevent(remove events)", error); + } +} - case SIGIO: - static_assert(SIGIO == SIGPOLL, "SIGIO != SIGPOLL?"); +void UnixEventPort::FdObserver::fire(struct kevent event) { + switch (event.filter) { + case EVFILT_READ: + if (event.flags & EV_EOF) { + atEnd = true; + } else { + atEnd = false; + } + + KJ_IF_MAYBE(f, readFulfiller) { + f->get()->fulfill(); + readFulfiller = nullptr; + } + break; - // Note: Technically, code can arrange for SIGIO signals to be delivered with a signal number - // other than SIGIO. AFAICT there is no way for us to detect this in the siginfo. Luckily - // SIGIO is totally obsoleted by epoll so it shouldn't come up. + case EVFILT_WRITE: + if (event.flags & EV_EOF) { + // EOF on write indicates disconnect. + KJ_IF_MAYBE(f, hupFulfiller) { + f->get()->fulfill(); + hupFulfiller = nullptr; + if (!(flags & OBSERVE_WRITE)) { + // We were only observing writes to get the disconnect event. Stop observing now. + struct kevent rmEvent; + EV_SET(&rmEvent, fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL_HANDLE_ERRORS(kevent(eventPort.kqueueFd, &rmEvent, 1, nullptr, 0, nullptr)) { + case ENOENT: + // In the specific case of unnamed pipes, when read end of the pipe is destroyed, + // FreeBSD seems to unregister the events on the write end automatically. + // Subsequently trying to remove them then produces ENOENT. Let's ignore this. + break; + default: + KJ_FAIL_SYSCALL("kevent(remove events)", error); + } + } + } + } - result.si_band = siginfo.ssi_band; - result.si_fd = siginfo.ssi_fd; + KJ_IF_MAYBE(f, writeFulfiller) { + f->get()->fulfill(); + writeFulfiller = nullptr; + } break; - case SIGSYS: - // Apparently SIGSYS's fields are not available in signalfd_siginfo? +#ifdef EVFILT_EXCEPT + case EVFILT_EXCEPT: + KJ_IF_MAYBE(f, urgentFulfiller) { + f->get()->fulfill(); + urgentFulfiller = nullptr; + } break; - } +#endif + } +} - } else { - // Signal originated from userspace. The sender could specify whatever signal number they - // wanted. The structure of the signal is determined by the API they used, which is identified - // by SI_CODE. - - switch (siginfo.ssi_code) { - case SI_USER: - case SI_TKILL: - // kill(), tkill(), or tgkill(). - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - break; +Promise UnixEventPort::FdObserver::whenBecomesReadable() { + KJ_REQUIRE(flags & OBSERVE_READ, "FdObserver was not set to observe reads."); - case SI_QUEUE: - case SI_MESGQ: - case SI_ASYNCIO: - default: - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - - // This is awkward. In siginfo_t, si_ptr and si_int are in a union together. In - // signalfd_siginfo, they are not. We don't really know whether the app intended to send - // an int or a pointer. Presumably since the pointer is always larger than the int, if - // we write the pointer, we'll end up with the right value for the int? Presumably the - // two fields of signalfd_siginfo are actually extracted from one of these unions - // originally, so actually contain redundant data? Better write some tests... - // - // Making matters even stranger, siginfo.ssi_ptr is 64-bit even on 32-bit systems, and - // it appears that instead of doing the obvious thing by casting the pointer value to - // 64 bits, the kernel actually memcpy()s the 32-bit value into the 64-bit space. As - // a result, on big-endian 32-bit systems, the original pointer value ends up in the - // *upper* 32 bits of siginfo.ssi_ptr, which is totally weird. We play along and use - // a memcpy() on our end too, to get the right result on all platforms. - memcpy(&result.si_ptr, &siginfo.ssi_ptr, sizeof(result.si_ptr)); - break; + auto paf = newPromiseAndFulfiller(); + readFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} - case SI_TIMER: - result.si_timerid = siginfo.ssi_tid; - result.si_overrun = siginfo.ssi_overrun; +Promise UnixEventPort::FdObserver::whenBecomesWritable() { + KJ_REQUIRE(flags & OBSERVE_WRITE, "FdObserver was not set to observe writes."); - // Again with this weirdness... - result.si_ptr = reinterpret_cast(static_cast(siginfo.ssi_ptr)); - break; - } + auto paf = newPromiseAndFulfiller(); + writeFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + +Promise UnixEventPort::FdObserver::whenUrgentDataAvailable() { + KJ_REQUIRE(flags & OBSERVE_URGENT, + "FdObserver was not set to observe availability of urgent data."); + + auto paf = newPromiseAndFulfiller(); + urgentFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + +Promise UnixEventPort::FdObserver::whenWriteDisconnected() { + if (!(flags & OBSERVE_WRITE) && hupFulfiller == nullptr) { + // We aren't observing writes, but we need to if we want to detect disconnects. + struct kevent event; + EV_SET(&event, fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); } - return result; + auto paf = newPromiseAndFulfiller(); + hupFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); } -bool UnixEventPort::doEpollWait(int timeout) { - sigset_t newMask; - memset(&newMask, 0, sizeof(newMask)); - sigemptyset(&newMask); +class UnixEventPort::SignalPromiseAdapter { +public: + inline SignalPromiseAdapter(PromiseFulfiller& fulfiller, + UnixEventPort& eventPort, int signum) + : eventPort(eventPort), signum(signum), fulfiller(fulfiller) { + struct kevent event; + EV_SET(&event, signum, EVFILT_SIGNAL, EV_ADD | EV_CLEAR, 0, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // We must check for the signal now in case it was delivered previously and is currently in + // the blocked set. See comment in tryConsumeSignal(). (To avoid the race condition, we must + // check *after* having registered the kevent!) + tryConsumeSignal(); + } - { - auto ptr = signalHead; - while (ptr != nullptr) { - sigaddset(&newMask, ptr->signum); - ptr = ptr->next; + ~SignalPromiseAdapter() noexcept(false) { + // Unregister the event. This is important because it contains a pointer to this object which + // we don't want to see again. + struct kevent event; + EV_SET(&event, signum, EVFILT_SIGNAL, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + } + + void tryConsumeSignal() { + // Unfortunately KJ's signal semantics are not a great fit for kqueue. In particular, KJ + // assumes that if no threads are waiting for a signal, it'll remain blocked until some + // thread actually calls `onSignal()` to receive it. kqueue, however, doesn't care if a signal + // is blocked -- the kqueue event will still be delivered. So, when `onSignal()` is called + // we will need to check if the signal is already queued; it's too late to ask kqueue() to + // tell us this. + // + // Alternatively we could maybe fix this by having every thread's kqueue wait on all captured + // signals all the time, but this would result in a thundering herd on any signal even if only + // one thread has actually registered interest. + // + // Another problem is per-thread signals, delivered with pthread_kill(). On FreeBSD, it appears + // a pthread_kill will wake up all kqueues in the process waiting on the particular signal, + // even if they are not associated with the target thread (kqueues don't really have any + // association with threads anyway). Worse, though, on MacOS, pthread_kill() doesn't wake + // kqueues at all. In fact, it appears they made it this way in 10.14, which broke stuff: + // https://github.com/libevent/libevent/issues/765 + // + // So, we have to: + // - Block signals normally. + // - Poll for a specific signal using sigtimedwait() or similar. + // - Use kqueue only as a hint to tell us when polling might be a good idea. + // - On MacOS, live with per-thread signals being broken I guess? + + // Anyway, this method here tries to have the signal delivered to this thread. + + if (fulfiller.isWaiting()) { +#if KJ_HAS_SIGTIMEDWAIT + sigset_t mask; + KJ_SYSCALL(sigemptyset(&mask)); + KJ_SYSCALL(sigaddset(&mask, signum)); + siginfo_t result; + struct timespec timeout; + memset(&timeout, 0, sizeof(timeout)); + + KJ_SYSCALL_HANDLE_ERRORS(sigtimedwait(&mask, &result, &timeout)) { + case EAGAIN: + // Signal was not queued. + return; + default: + KJ_FAIL_SYSCALL("sigtimedwait", error); + } + + fulfiller.fulfill(kj::mv(result)); +#else + // This platform doesn't appear to have sigtimedwait(). Ugh! We are forced to do two separate + // syscalls to see if the signal is pending, and then, if so, wait for it. There is an + // inherent race condition since the signal could be dequeued in another thread concurrently. + // We will try to work around that by locking a global mutex, so at least this code doesn't + // race against itself. + static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mut); + KJ_DEFER(pthread_mutex_unlock(&mut)); + + sigset_t mask; + KJ_SYSCALL(sigpending(&mask)); + int isset; + KJ_SYSCALL(isset = sigismember(&mask, signum)); + if (isset) { + KJ_SYSCALL(sigfillset(&mask)); + KJ_SYSCALL(sigdelset(&mask, signum)); + siginfo_t info; + memset(&info, 0, sizeof(info)); + threadCapture = &info; + KJ_DEFER(threadCapture = nullptr); + int result = sigsuspend(&mask); + KJ_ASSERT(result < 0 && errno == EINTR, "sigsuspend() didn't EINTR?", result, errno); + KJ_ASSERT(info.si_signo == signum); + fulfiller.fulfill(kj::mv(info)); + } +#endif } - if (childSet != nullptr) { - sigaddset(&newMask, SIGCHLD); + } + + UnixEventPort& eventPort; + int signum; + PromiseFulfiller& fulfiller; +}; + +Promise UnixEventPort::onSignal(int signum) { + KJ_REQUIRE(signum != SIGCHLD || !capturedChildExit, + "can't call onSigal(SIGCHLD) when kj::UnixEventPort::captureChildExit() has been called"); + + return newAdaptedPromise(*this, signum); +} + +class UnixEventPort::ChildExitPromiseAdapter { +public: + inline ChildExitPromiseAdapter(PromiseFulfiller& fulfiller, + UnixEventPort& eventPort, Maybe& pid) + : eventPort(eventPort), pid(pid), fulfiller(fulfiller) { + pid_t p = KJ_ASSERT_NONNULL(pid); + + struct kevent event; + EV_SET(&event, p, EVFILT_PROC, EV_ADD | EV_CLEAR, NOTE_EXIT, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // Check for race where child had already exited before the event was waiting. + tryConsumeChild(); + } + + ~ChildExitPromiseAdapter() noexcept(false) { + KJ_IF_MAYBE(p, pid) { + // The process has not been reaped. The promise must have been canceled. So, we're still + // registered with the kqueue. We'd better unregister because the kevent points back to this + // object. + struct kevent event; + EV_SET(&event, *p, EVFILT_PROC, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // We leak the zombie process here. The caller is responsible for doing its own waitpid(). } } - if (memcmp(&newMask, &signalFdSigset, sizeof(newMask)) != 0) { - // Apparently we're not waiting on the same signals as last time. Need to update the signal - // FD's mask. - signalFdSigset = newMask; - KJ_SYSCALL(signalfd(signalFd, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC)); + void tryConsumeChild() { + // Even though kqueue delivers the exit status to us, we still need to wait on the pid to + // clear the zombie. We can't set SIGCHLD to SIG_IGN to ignore this because it creates a race + // condition. + + KJ_IF_MAYBE(p, pid) { + int status; + pid_t result; + KJ_SYSCALL(result = waitpid(*p, &status, WNOHANG)); + if (result != 0) { + KJ_ASSERT(result == *p); + + // NOTE: The proc is automatically unregsitered from the kqueue on exit, so we should NOT + // attempt to unregister it here. + + pid = nullptr; + fulfiller.fulfill(kj::mv(status)); + } + } } - struct epoll_event events[16]; - int n = epoll_wait(epollFd, events, kj::size(events), timeout); + UnixEventPort& eventPort; + Maybe& pid; + PromiseFulfiller& fulfiller; +}; + +Promise UnixEventPort::onChildExit(Maybe& pid) { + KJ_REQUIRE(capturedChildExit, + "must call UnixEventPort::captureChildExit() to use onChildExit()."); + + return kj::newAdaptedPromise(*this, pid); +} + +void UnixEventPort::captureChildExit() { + capturedChildExit = true; +} + +void UnixEventPort::wake() const { + // Trigger our user event. + struct kevent event; + EV_SET(&event, 0, EVFILT_USER, 0, NOTE_TRIGGER, 0, nullptr); + KJ_SYSCALL(kevent(kqueueFd, &event, 1, nullptr, 0, nullptr)); +} + +bool UnixEventPort::doKqueueWait(struct timespec* timeout) { + struct kevent events[16]; + int n = kevent(kqueueFd, nullptr, 0, events, kj::size(events), timeout); + if (n < 0) { int error = errno; if (error == EINTR) { - // We can't simply restart the epoll call because we need to recompute the timeout. Instead, - // we pretend epoll_wait() returned zero events. This will cause the event loop to spin once, - // decide it has nothing to do, recompute timeouts, then return to waiting. + // We received a singal. The signal handler may have queued an event to the event loop. Even + // if it didn't, we can't simply restart the kevent call because we need to recompute the + // timeout. Instead, we pretend kevent() returned zero events. This will cause the event + // loop to spin once, decide it has nothing to do, recompute timeouts, then return to waiting. n = 0; } else { - KJ_FAIL_SYSCALL("epoll_wait()", error); + KJ_FAIL_SYSCALL("kevent()", error); } } bool woken = false; for (int i = 0; i < n; i++) { - if (events[i].data.u64 == 0) { - for (;;) { - struct signalfd_siginfo siginfo; - ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = read(signalFd, &siginfo, sizeof(siginfo))); - if (n < 0) break; // no more signals - - KJ_ASSERT(n == sizeof(siginfo)); - - gotSignal(toRegularSiginfo(siginfo)); - -#ifdef SIGRTMIN - if (siginfo.ssi_signo >= SIGRTMIN) { - // This is an RT signal. There could be multiple copies queued. We need to remove it from - // the signalfd's signal mask before we continue, to avoid accidentally reading and - // discarding the extra copies. - // TODO(perf): If high throughput of RT signals is desired then perhaps we should read - // them all into userspace and queue them here. Maybe we even need a better interface - // than onSignal() for receiving high-volume RT signals. - KJ_SYSCALL(sigdelset(&signalFdSigset, siginfo.ssi_signo)); - KJ_SYSCALL(signalfd(signalFd, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC)); - } + switch (events[i].filter) { +#ifdef EVFILT_EXCEPT + case EVFILT_EXCEPT: #endif + case EVFILT_READ: + case EVFILT_WRITE: { + FdObserver* observer = reinterpret_cast(events[i].udata); + observer->fire(events[i]); + break; } - } else if (events[i].data.u64 == 1) { - // Someone called wake() from another thread. Consume the event. - uint64_t value; - ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value))); - KJ_ASSERT(n < 0 || n == sizeof(value)); - // We were woken. Need to return true. - woken = true; - } else { - FdObserver* observer = reinterpret_cast(events[i].data.ptr); - observer->fire(events[i].events); + case EVFILT_SIGNAL: { + SignalPromiseAdapter* observer = reinterpret_cast(events[i].udata); + observer->tryConsumeSignal(); + break; + } + + case EVFILT_PROC: { + ChildExitPromiseAdapter* observer = + reinterpret_cast(events[i].udata); + observer->tryConsumeChild(); + break; + } + + case EVFILT_USER: + // Someone called wake() from another thread. + woken = true; + break; + + default: + KJ_FAIL_ASSERT("unexpected EVFILT", events[i].filter); } } @@ -653,7 +1128,24 @@ bool UnixEventPort::doEpollWait(int timeout) { return woken; } -#else // KJ_USE_EPOLL +bool UnixEventPort::wait() { + KJ_IF_MAYBE(t, timerImpl.timeoutToNextEvent(clock.now(), NANOSECONDS, int(maxValue))) { + struct timespec timeout; + timeout.tv_sec = *t / 1'000'000'000; + timeout.tv_nsec = *t % 1'000'000'000; + return doKqueueWait(&timeout); + } else { + return doKqueueWait(nullptr); + } +} + +bool UnixEventPort::poll() { + struct timespec timeout; + memset(&timeout, 0, sizeof(timeout)); + return doKqueueWait(&timeout); +} + +#else // KJ_USE_EPOLL, else KJ_USE_KQUEUE // ======================================================================================= // Traditional poll() FdObserver implementation. @@ -1089,7 +1581,7 @@ void UnixEventPort::wake() const { #endif } -#endif // KJ_USE_EPOLL, else +#endif // KJ_USE_EPOLL, else KJ_USE_KQUEUE, else } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.h index 63fe92790f0..665305ea70c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-unix.h @@ -27,15 +27,24 @@ #include "async.h" #include "timer.h" -#include "vector.h" -#include "io.h" +#include +#include #include KJ_BEGIN_HEADER -#if __linux__ && !__BIONIC__ && !defined(KJ_USE_EPOLL) -// Default to epoll on Linux, except on Bionic (Android) which doesn't have signalfd.h. +#if !defined(KJ_USE_EPOLL) && !defined(KJ_USE_KQUEUE) +#if __linux__ +// Default to epoll on Linux. #define KJ_USE_EPOLL 1 +#elif __APPLE__ || __FreeBSD__ || __OpenBSD__ || __NetBSD__ || __DragonFly__ +// MacOS and BSDs prefer kqueue() for event notification. +#define KJ_USE_KQUEUE 1 +#endif +#endif + +#if KJ_USE_EPOLL && KJ_USE_KQUEUE +#error "Both KJ_USE_EPOLL and KJ_USE_KQUEUE are set. Please choose only one of these." #endif #if __CYGWIN__ && !defined(KJ_USE_PIPE_FOR_WAKEUP) @@ -46,6 +55,13 @@ KJ_BEGIN_HEADER #define KJ_USE_PIPE_FOR_WAKEUP 1 #endif +#if KJ_USE_EPOLL +struct epoll_event; +#elif KJ_USE_KQUEUE +struct kevent; +struct timespec; +#endif + namespace kj { class UnixEventPort: public EventPort { @@ -55,20 +71,26 @@ class UnixEventPort: public EventPort { // The implementation uses `poll()` or possibly a platform-specific API (e.g. epoll, kqueue). // To also wait on signals without race conditions, the implementation may block signals until // just before `poll()` while using a signal handler which `siglongjmp()`s back to just before - // the signal was unblocked, or it may use a nicer platform-specific API like signalfd. + // the signal was unblocked, or it may use a nicer platform-specific API. // // The implementation reserves a signal for internal use. By default, it uses SIGUSR1. If you // need to use SIGUSR1 for something else, you must offer a different signal by calling - // setReservedSignal() at startup. + // setReservedSignal() at startup. (On Linux, no signal is reserved; eventfd is used instead.) // // WARNING: A UnixEventPort can only be used in the thread and process that created it. In // particular, note that after a fork(), a UnixEventPort created in the parent process will // not work correctly in the child, even if the parent ceases to use its copy. In particular // note that this means that server processes which daemonize themselves at startup must wait // until after daemonization to create a UnixEventPort. + // + // TODO(cleanup): The above warning is no longer accurate -- daemonizing after creating a + // UnixEventPort should now work since we no longer use signalfd. But do we want to commit to + // keeping it that way? Note it's still unsafe to fork() and then use UnixEventPort from both + // processes! public: UnixEventPort(); + ~UnixEventPort() noexcept(false); class FdObserver; @@ -84,6 +106,15 @@ class UnixEventPort: public EventPort { // process-wide signal by only calling `onSignal()` on that thread's event loop. // // The result of waiting on the same signal twice at once is undefined. + // + // WARNING: On MacOS and iOS, `onSignal()` will only see process-level signals, NOT + // thread-specific signals (i.e. not those sent with pthread_kill()). This is a limitation of + // Apple's implemnetation of kqueue() introduced in MacOS 10.14 which Apple says is not a bug. + // See: https://github.com/libevent/libevent/issues/765 Consider using kj::Executor or + // kj::newPromiseAndCrossThreadFulfiller() for cross-thread communications instead of signals. + // If you must have signals, build KJ and your app with `-DKJ_USE_KQUEUE=0`, which will cause + // KJ to fall back to a generic poll()-based implementation that is less efficient but handles + // thread-specific signals. static void captureSignal(int signum); // Arranges for the given signal to be captured and handled via UnixEventPort, so that you may @@ -119,6 +150,14 @@ class UnixEventPort: public EventPort { // .then() continuation may not run immediately, we need a more precise way, hence we null out // the Maybe. // + // The caller must NOT null out `pid` on its own unless it cancels the Promise first. If the + // caller decides to cancel the Promise, and `pid` is still non-null after this cancellation, + // then the caller is expected to `waitpid()` on it BEFORE returning to the event loop again. + // Probably, the caller should kill() the child before waiting to avoid a hang. If the caller + // fails to do its own waitpid() before returning to the event loop, the child may become a + // zombie, or may be reaped automatically, depending on the platform -- since the caller does not + // know, the caller cannot try to reap the zombie later. + // // You must call `kj::UnixEventPort::captureChildExit()` early in your program if you want to use // `onChildExit()`. // @@ -151,24 +190,25 @@ class UnixEventPort: public EventPort { const MonotonicClock& clock; TimerImpl timerImpl; +#if !KJ_USE_KQUEUE SignalPromiseAdapter* signalHead = nullptr; SignalPromiseAdapter** signalTail = &signalHead; void gotSignal(const siginfo_t& siginfo); +#endif friend class TimerPromiseAdapter; #if KJ_USE_EPOLL + sigset_t originalMask; AutoCloseFd epollFd; - AutoCloseFd signalFd; AutoCloseFd eventFd; // Used for cross-thread wakeups. - sigset_t signalFdSigset; - // Signal mask as currently set on the signalFd. Tracked so we can detect whether or not it - // needs updating. - - bool doEpollWait(int timeout); + bool processEpollEvents(struct epoll_event events[], int n); +#elif KJ_USE_KQUEUE + AutoCloseFd kqueueFd; + bool doKqueueWait(struct timespec* timeout); #else class PollContext; @@ -183,11 +223,20 @@ class UnixEventPort: public EventPort { #endif #endif +#if !KJ_USE_KQUEUE struct ChildSet; Maybe> childSet; +#endif + + static void signalHandler(int, siginfo_t* siginfo, void*) noexcept; + static void registerSignalHandler(int signum); +#if !KJ_USE_EPOLL && !KJ_USE_KQUEUE && !KJ_USE_PIPE_FOR_WAKEUP + static void registerReservedSignal(); +#endif + static void ignoreSigpipe(); }; -class UnixEventPort::FdObserver { +class UnixEventPort::FdObserver: private AsyncObject { // Object which watches a file descriptor to determine when it is readable or writable. // // For listen sockets, "readable" means that there is a connection to accept(). For everything @@ -216,7 +265,7 @@ class UnixEventPort::FdObserver { ~FdObserver() noexcept(false); - KJ_DISALLOW_COPY(FdObserver); + KJ_DISALLOW_COPY_AND_MOVE(FdObserver); Promise whenBecomesReadable(); // Resolves the next time the file descriptor transitions from having no data to read to having @@ -287,7 +336,7 @@ class UnixEventPort::FdObserver { // has not yet resolved. If you do this, the previous promise may throw an exception. // // WARNING: This has some known weird behavior on macOS. See - // https://github.com/sandstorm-io/capnproto/issues/374. + // https://github.com/capnproto/capnproto/issues/374. Promise whenWriteDisconnected(); // Resolves when poll() on the file descriptor reports POLLHUP or POLLERR. @@ -306,7 +355,11 @@ class UnixEventPort::FdObserver { Maybe atEnd; +#if KJ_USE_KQUEUE + void fire(struct kevent event); +#else void fire(short events); +#endif #if !KJ_USE_EPOLL FdObserver* next; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.c++ index ba0ee88b8e2..ed82e552065 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.c++ @@ -22,7 +22,7 @@ #if _WIN32 // Request Vista-level APIs. -#include "win32-api-version.h" +#include #include "async-win32.h" #include "debug.h" @@ -158,6 +158,11 @@ Own Win32IocpEventPort::observeSignalState(HANDL } bool Win32IocpEventPort::wait() { + // It's possible that a wake event was received and discarded during ~IoPromiseAdapter. We + // need to check for that now. Otherwise, calling waitIocp may cause it to hang forever. + if (receivedWake()) { + return true; + } waitIocp(timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, INFINITE - 1) .map([](uint64_t t) -> DWORD { return t; }) .orDefault(INFINITE)); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.h b/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.h index 2085118c792..ddf4987a7a5 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-win32.h @@ -27,7 +27,7 @@ // Include windows.h as lean as possible. (If you need more of the Windows API for your app, // #include windows.h yourself before including this header.) -#include "win32-api-version.h" +#include #include "async.h" #include "timer.h" @@ -36,7 +36,9 @@ #include #include -#include "windows-sanity.h" +#include + +KJ_BEGIN_HEADER namespace kj { @@ -118,7 +120,7 @@ class Win32EventPort: public EventPort { // Returns a promise that completes the next time the handle enters the signaled state. // // Depending on the type of handle, the handle may automatically be reset to a non-signaled - // state before the promise resolves. The underlying implementaiton uses WaitForSingleObject() + // state before the promise resolves. The underlying implementation uses WaitForSingleObject() // or an equivalent wait call, so check the documentation for that to understand the semantics. // // If the handle is a mutex and it is abandoned without being unlocked, the promise breaks with @@ -177,7 +179,7 @@ class Win32WaitObjectThreadPool { bool finishedMainThreadWait(DWORD returnCode); // Call immediately after invoking WaitForMultipleObjects() or similar in the main thread, - // passing the value returend by that call. Returns true if the event indicated by `returnCode` + // passing the value returned by that call. Returns true if the event indicated by `returnCode` // has been handled (i.e. it was WAIT_OBJECT_n or WAIT_ABANDONED_n where n is in-range for the // last call to prepareMainThreadWait()). }; @@ -227,3 +229,5 @@ class Win32IocpEventPort final: public Win32EventPort { }; } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async-xthread-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-xthread-test.c++ index 7d6fa80d6e5..b6bd237e196 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async-xthread-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async-xthread-test.c++ @@ -99,7 +99,7 @@ KJ_TEST("synchonous simple cross-thread events") { })(); } -KJ_TEST("asynchonous simple cross-thread events") { +KJ_TEST("asynchronous simple cross-thread events") { MutexGuarded> executor; // to get the Executor from the other thread Own> fulfiller; // accessed only from the subthread thread_local bool isChild = false; // to assert which thread we're in @@ -210,7 +210,7 @@ KJ_TEST("synchonous promise cross-thread events") { })(); } -KJ_TEST("asynchonous promise cross-thread events") { +KJ_TEST("asynchronous promise cross-thread events") { MutexGuarded> executor; // to get the Executor from the other thread Own> fulfiller; // accessed only from the subthread Promise promise = nullptr; // accessed only from the subthread diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/async.c++ index 0d2db1a1f4e..7c76f1d5026 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async.c++ @@ -26,11 +26,11 @@ // so this check isn't appropriate for us. #if _WIN32 || __CYGWIN__ -#include "win32-api-version.h" +#include #elif __APPLE__ // getcontext() and friends are marked deprecated on MacOS but seemingly no replacement is // provided. It appears as if they deprecated it solely because the standards bodies deprecated it, -// which they seemingly did mainly because the proper sematics are too difficult for them to +// which they seemingly did mainly because the proper semantics are too difficult for them to // define. I doubt MacOS would actually remove these functions as they are widely used. But if they // do, then I guess we'll need to fall back to using setjmp()/longjmp(), and some sort of hack // involving sigaltstack() (and generating a fake signal I guess) in order to initialize the fiber @@ -48,10 +48,11 @@ #include "function.h" #include "list.h" #include +#include #if _WIN32 || __CYGWIN__ #include // for Sleep(0) and fibers -#include "windows-sanity.h" +#include #else #if KJ_USE_FIBERS @@ -77,12 +78,24 @@ #include +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) +// Clang's address sanitizer requires special hints when switching fibers, especially in order for +// stack-use-after-return handling to work right. +// +// TODO(someday): Does GCC's sanitizer, flagged by __SANITIZE_ADDRESS__, have these hints too? I +// don't know and am not in a position to test, so I'm assuming not for now. +#include +#else +// Nop the hints so that we don't have to put #ifdefs around every use. +#define __sanitizer_start_switch_fiber(...) +#define __sanitizer_finish_switch_fiber(...) +#endif + #if _MSC_VER && !__clang__ // MSVC's atomic intrinsics are weird and different, whereas the C++ standard atomics match the GCC // builtins -- except for requiring the obnoxious std::atomic wrapper. So, on MSVC let's just // #define the builtins based on the C++ library, reinterpret-casting native types to // std::atomic... this is cheating but ugh, whatever. -#include template static std::atomic* reinterpretAtomic(T* ptr) { return reinterpret_cast*>(ptr); } #define __atomic_store_n(ptr, val, order) \ @@ -103,6 +116,51 @@ namespace kj { namespace { +KJ_THREADLOCAL_PTR(DisallowAsyncDestructorsScope) disallowAsyncDestructorsScope = nullptr; + +} // namespace + +AsyncObject::~AsyncObject() { + if (disallowAsyncDestructorsScope != nullptr) { + // If we try to do the KJ_FAIL_REQUIRE here (declaring `~AsyncObject()` itself to be noexcept), + // it seems to have a non-negligible performance impact in the HTTP benchmark. My guess is that + // it's because it breaks inlining of `~AsyncObject()` into various subclass destructors that + // are defined inside this file, which are some of the biggest ones. By forcing the actual + // failure code out into a separate function we get a little performance boost. + failed(); + } +} + +void AsyncObject::failed() noexcept { + // Since the method is noexcept, this will abort the process. + KJ_FAIL_REQUIRE( + kj::str("KJ async object being destroyed when not allowed: ", + disallowAsyncDestructorsScope->reason)); +} + +DisallowAsyncDestructorsScope::DisallowAsyncDestructorsScope(kj::StringPtr reason) + : reason(reason), previousValue(disallowAsyncDestructorsScope) { + requireOnStack(this, "DisallowAsyncDestructorsScope must be allocated on the stack."); + disallowAsyncDestructorsScope = this; +} + +DisallowAsyncDestructorsScope::~DisallowAsyncDestructorsScope() { + disallowAsyncDestructorsScope = previousValue; +} + +AllowAsyncDestructorsScope::AllowAsyncDestructorsScope() + : previousValue(disallowAsyncDestructorsScope) { + requireOnStack(this, "AllowAsyncDestructorsScope must be allocated on the stack."); + disallowAsyncDestructorsScope = nullptr; +} +AllowAsyncDestructorsScope::~AllowAsyncDestructorsScope() { + disallowAsyncDestructorsScope = previousValue; +} + +// ======================================================================================= + +namespace { + KJ_THREADLOCAL_PTR(EventLoop) threadLocalEventLoop = nullptr; #define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) @@ -115,7 +173,8 @@ EventLoop& currentEventLoop() { class RootEvent: public _::Event { public: - RootEvent(_::PromiseNode* node, void* traceAddr): node(node), traceAddr(traceAddr) {} + RootEvent(_::PromiseNode* node, void* traceAddr, SourceLocation location) + : Event(location), node(node), traceAddr(traceAddr) {} bool fired = false; @@ -138,45 +197,6 @@ struct DummyFunctor { void operator()() {}; }; -class YieldPromiseNode final: public _::PromiseNode { -public: - void onReady(_::Event* event) noexcept override { - if (event) event->armBreadthFirst(); - } - void get(_::ExceptionOrValue& output) noexcept override { - output.as<_::Void>() = _::Void(); - } - void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { - builder.add(reinterpret_cast(&kj::evalLater)); - } -}; - -class YieldHarderPromiseNode final: public _::PromiseNode { -public: - void onReady(_::Event* event) noexcept override { - if (event) event->armLast(); - } - void get(_::ExceptionOrValue& output) noexcept override { - output.as<_::Void>() = _::Void(); - } - void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { - builder.add(reinterpret_cast(&kj::evalLast)); - } -}; - -class NeverDonePromiseNode final: public _::PromiseNode { -public: - void onReady(_::Event* event) noexcept override { - // ignore - } - void get(_::ExceptionOrValue& output) noexcept override { - KJ_FAIL_REQUIRE("Not ready."); - } - void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { - builder.add(_::getMethodStartAddress(kj::NEVER_DONE, &_::NeverDone::wait)); - } -}; - } // namespace // ======================================================================================= @@ -261,19 +281,22 @@ void Canceler::AdapterImpl::cancel(kj::Exception&& e) { // ======================================================================================= -TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler) - : errorHandler(errorHandler) {} -class TaskSet::Task final: public _::Event { +TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler, SourceLocation location) + : errorHandler(errorHandler), location(location) {} + +class TaskSet::Task final: public _::PromiseArenaMember, public _::Event { public: - Task(TaskSet& taskSet, Own<_::PromiseNode>&& nodeParam) - : taskSet(taskSet), node(kj::mv(nodeParam)) { + Task(_::OwnPromiseNode&& nodeParam, TaskSet& taskSet) + : Event(taskSet.location), taskSet(taskSet), node(kj::mv(nodeParam)) { node->setSelfPointer(&node); node->onReady(this); } - Own pop() { + void destroy() override { freePromise(this); } + + OwnTask pop() { KJ_IF_MAYBE(n, next) { n->get()->prev = prev; } - Own self = kj::mv(KJ_ASSERT_NONNULL(*prev)); + OwnTask self = kj::mv(KJ_ASSERT_NONNULL(*prev)); KJ_ASSERT(self.get() == this); *prev = kj::mv(next); next = nullptr; @@ -281,8 +304,8 @@ public: return self; } - Maybe> next; - Maybe>* prev = nullptr; + Maybe next; + Maybe* prev = nullptr; kj::String trace() { void* space[32]; @@ -304,14 +327,12 @@ protected: result.addException(kj::mv(*exception)); } - // Call the error handler if there was an exception. - KJ_IF_MAYBE(e, result.exception) { - taskSet.errorHandler.taskFailed(kj::mv(*e)); - } - - // Remove from the task list. + // Remove from the task list. Do this before calling taskFailed(), so that taskFailed() can + // safely call clear(). auto self = pop(); + // We'll also process onEmpty() now, just in case `taskFailed()` actually destroys the whole + // `TaskSet`. KJ_IF_MAYBE(f, taskSet.emptyFulfiller) { if (taskSet.tasks == nullptr) { f->get()->fulfill(); @@ -319,7 +340,12 @@ protected: } } - return mv(self); + // Call the error handler if there was an exception. + KJ_IF_MAYBE(e, result.exception) { + taskSet.errorHandler.taskFailed(kj::mv(*e)); + } + + return Own(mv(self)); } void traceEvent(_::TraceBuilder& builder) override { @@ -330,7 +356,7 @@ protected: private: TaskSet& taskSet; - Own<_::PromiseNode> node; + _::OwnPromiseNode node; }; TaskSet::~TaskSet() noexcept(false) { @@ -344,7 +370,7 @@ TaskSet::~TaskSet() noexcept(false) { } void TaskSet::add(Promise&& promise) { - auto task = heap(*this, _::PromiseNode::from(kj::mv(promise))); + auto task = _::appendPromise(_::PromiseNode::from(kj::mv(promise)), *this); KJ_IF_MAYBE(head, tasks) { head->get()->prev = &task->next; task->next = kj::mv(tasks); @@ -356,7 +382,7 @@ void TaskSet::add(Promise&& promise) { kj::String TaskSet::trace() { kj::Vector traces; - Maybe>* ptr = &tasks; + Maybe* ptr = &tasks; for (;;) { KJ_IF_MAYBE(task, *ptr) { traces.add(task->get()->trace()); @@ -385,6 +411,14 @@ Promise TaskSet::onEmpty() { } } +void TaskSet::clear() { + tasks = nullptr; + + KJ_IF_MAYBE(fulfiller, emptyFulfiller) { + fulfiller->get()->fulfill(); + } +} + // ======================================================================================= namespace { @@ -844,8 +878,9 @@ struct Executor::Impl { namespace _ { // (private) XThreadEvent::XThreadEvent( - ExceptionOrValue& result, const Executor& targetExecutor, void* funcTracePtr) - : Event(targetExecutor.getLoop()), result(result), funcTracePtr(funcTracePtr), + ExceptionOrValue& result, const Executor& targetExecutor, EventLoop& loop, + void* funcTracePtr, SourceLocation location) + : Event(loop, location), result(result), funcTracePtr(funcTracePtr), targetExecutor(targetExecutor.addRef()) {} void XThreadEvent::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { @@ -1038,7 +1073,7 @@ void XThreadEvent::done() { lock->executing.remove(*this); break; case CANCELING: - // Sending thread requested cancelation, but we're done anyway, so it doesn't matter at this + // Sending thread requested cancellation, but we're done anyway, so it doesn't matter at this // point. lock->cancel.remove(*this); break; @@ -1120,36 +1155,33 @@ XThreadPaf::XThreadPaf() : state(WAITING), executor(getCurrentThreadExecutor()) {} XThreadPaf::~XThreadPaf() noexcept(false) {} -void XThreadPaf::Disposer::disposeImpl(void* pointer) const { - XThreadPaf* obj = reinterpret_cast(pointer); +void XThreadPaf::destroy() { auto oldState = WAITING; - if (__atomic_load_n(&obj->state, __ATOMIC_ACQUIRE) == DISPATCHED) { + if (__atomic_load_n(&state, __ATOMIC_ACQUIRE) == DISPATCHED) { // Common case: Promise was fully fulfilled and dispatched, no need for locking. - delete obj; - } else if (__atomic_compare_exchange_n(&obj->state, &oldState, CANCELED, false, + delete this; + } else if (__atomic_compare_exchange_n(&state, &oldState, CANCELED, false, __ATOMIC_ACQUIRE, __ATOMIC_ACQUIRE)) { // State transitioned from WAITING to CANCELED, so now it's the fulfiller's job to destroy the // object. } else { // Whoops, another thread is already in the process of fulfilling this promise. We'll have to // wait for it to finish and transition the state to FULFILLED. - obj->executor.impl->state.when([&](auto&) { - return obj->state == FULFILLED || obj->state == DISPATCHED; + executor.impl->state.when([&](auto&) { + return state == FULFILLED || state == DISPATCHED; }, [&](Executor::Impl::State& exState) { - if (obj->state == FULFILLED) { + if (state == FULFILLED) { // The object is on the queue but was not yet dispatched. Remove it. - exState.fulfilled.remove(*obj); + exState.fulfilled.remove(*this); } }); // It's ours now, delete it. - delete obj; + delete this; } } -const XThreadPaf::Disposer XThreadPaf::DISPOSER; - void XThreadPaf::onReady(Event* event) noexcept { onReadyEvent.init(event); } @@ -1328,43 +1360,70 @@ struct FiberStack::Impl { jmp_buf fiberJmpBuf; jmp_buf originalJmpBuf; +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // Stuff that we need to pass to __sanitizer_start_switch_fiber() / + // __sanitizer_finish_switch_fiber() when using ASAN. + + void* originalFakeStack = nullptr; + void* fiberFakeStack = nullptr; + // Pointer to ASAN "fake stack" associated with the fiber and its calling stack. Filled in by + // __sanitizer_start_switch_fiber() before switching away, consumed by + // __sanitizer_finish_switch_fiber() upon switching back. + + void const* originalBottom; + size_t originalSize; + // Size and location of the original stack before switching fibers. These are filled in by + // __sanitizer_finish_switch_fiber() after the switch, and must be passed to + // __sanitizer_start_switch_fiber() when switching back later. +#endif + static Impl* alloc(size_t stackSize, ucontext_t* context) { #ifndef MAP_ANONYMOUS #define MAP_ANONYMOUS MAP_ANON #endif -#ifndef MAP_STACK -#define MAP_STACK 0 -#endif - size_t pageSize = getPageSize(); - size_t allocSize = stackSize + pageSize; // size plus guard page + size_t allocSize = stackSize + pageSize; // size plus guard page and impl // Allocate virtual address space for the stack but make it inaccessible initially. // TODO(someday): Does it make sense to use MAP_GROWSDOWN on Linux? It's a kind of bizarre flag // that causes the mapping to automatically allocate extra pages (beyond the range specified) - // until it hits something... - void* stack = mmap(nullptr, allocSize, PROT_NONE, - MAP_ANONYMOUS | MAP_PRIVATE | MAP_STACK, -1, 0); - if (stack == MAP_FAILED) { + // until it hits something... Note that on FreeBSD, MAP_STACK has the effect that + // MAP_GROWSDOWN has on Linux. (MAP_STACK, meanwhile, has no effect on Linux.) + void* stackMapping = mmap(nullptr, allocSize, PROT_NONE, + MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); + if (stackMapping == MAP_FAILED) { KJ_FAIL_SYSCALL("mmap(new stack)", errno); } KJ_ON_SCOPE_FAILURE({ - KJ_SYSCALL(munmap(stack, allocSize)) { break; } + KJ_SYSCALL(munmap(stackMapping, allocSize)) { break; } }); + void* stack = reinterpret_cast(stackMapping) + pageSize; // Now mark everything except the guard page as read-write. We assume the stack grows down, so // the guard page is at the beginning. No modern architecture uses stacks that grow up. - KJ_SYSCALL(mprotect(reinterpret_cast(stack) + pageSize, - stackSize, PROT_READ | PROT_WRITE)); + KJ_SYSCALL(mprotect(stack, stackSize, PROT_READ | PROT_WRITE)); // Stick `Impl` at the top of the stack. - Impl* impl = (reinterpret_cast(reinterpret_cast(stack) + allocSize) - 1); + Impl* impl = (reinterpret_cast(reinterpret_cast(stack) + stackSize) - 1); // Note: mmap() allocates zero'd pages so we don't have to memset() anything here. KJ_SYSCALL(getcontext(context)); - context->uc_stack.ss_size = allocSize - sizeof(Impl); +#if __APPLE__ && __aarch64__ + // Per issue #1386, apple on arm64 zeros the entire configured stack. + // But this is redundant, since we just allocated the stack with mmap() which + // returns zero'd pages. Re-zeroing is both slow and results in prematurely + // allocating pages we may not need -- it's normal for stacks to rely heavily + // on lazy page allocation to avoid wasting memory. Instead, we lie: + // we allocate the full size, but tell the ucontext the stack is the last + // page only. This appears to work as no particular bounds checks or + // anything are set up based on what we say here. + context->uc_stack.ss_size = min(pageSize, stackSize) - sizeof(Impl); + context->uc_stack.ss_sp = reinterpret_cast(stack) + stackSize - min(pageSize, stackSize); +#else + context->uc_stack.ss_size = stackSize - sizeof(Impl); context->uc_stack.ss_sp = reinterpret_cast(stack); +#endif context->uc_stack.ss_flags = 0; // We don't use uc_link since our fiber start routine runs forever in a loop to allow for // reuse. When we're done with the fiber, we just destroy it, without switching to it's @@ -1385,7 +1444,7 @@ struct FiberStack::Impl { #ifndef _SC_PAGESIZE #define _SC_PAGESIZE _SC_PAGE_SIZE #endif - static size_t result = sysconf(_SC_PAGE_SIZE); + static size_t result = sysconf(_SC_PAGESIZE); return result; } }; @@ -1409,6 +1468,9 @@ struct FiberStack::StartRoutine { auto& stack = *reinterpret_cast(ptr); + __sanitizer_finish_switch_fiber(nullptr, + &stack.impl->originalBottom, &stack.impl->originalSize); + // We first switch to the fiber inside of the FiberStack constructor. This is just for // initialization purposes, and we're expected to switch back immediately. stack.switchToMain(); @@ -1464,9 +1526,11 @@ FiberStack::FiberStack(size_t stackSizeParam) makecontext(&context, reinterpret_cast(&StartRoutine::run), 2, arg1, arg2); + __sanitizer_start_switch_fiber(&impl->originalFakeStack, impl, stackSize - sizeof(Impl)); if (_setjmp(impl->originalJmpBuf) == 0) { setcontext(&context); } + __sanitizer_finish_switch_fiber(impl->originalFakeStack, nullptr, nullptr); #endif #else #if KJ_NO_EXCEPTIONS @@ -1501,14 +1565,14 @@ void FiberStack::initialize(SynchronousFunc& func) { this->main = &func; } -FiberBase::FiberBase(size_t stackSize, _::ExceptionOrValue& result) - : state(WAITING), stack(kj::heap(stackSize)), result(result) { +FiberBase::FiberBase(size_t stackSize, _::ExceptionOrValue& result, SourceLocation location) + : Event(location), state(WAITING), stack(kj::heap(stackSize)), result(result) { stack->initialize(*this); ensureThreadCanRunFibers(); } -FiberBase::FiberBase(const FiberPool& pool, _::ExceptionOrValue& result) - : state(WAITING), result(result) { +FiberBase::FiberBase(const FiberPool& pool, _::ExceptionOrValue& result, SourceLocation location) + : Event(location), state(WAITING), result(result) { stack = pool.impl->takeStack(); stack->initialize(*this); ensureThreadCanRunFibers(); @@ -1516,7 +1580,7 @@ FiberBase::FiberBase(const FiberPool& pool, _::ExceptionOrValue& result) FiberBase::~FiberBase() noexcept(false) {} -void FiberBase::destroy() { +void FiberBase::cancel() { // Called by `~Fiber()` to begin teardown. We can't do this work in `~FiberBase()` because the // `Fiber` subclass contains members that may still be in-use until the fiber stops. @@ -1538,7 +1602,7 @@ void FiberBase::destroy() { case RUNNING: case CANCELED: // Bad news. - KJ_LOG(FATAL, "fiber tried to destroy itself"); + KJ_LOG(FATAL, "fiber tried to cancel itself"); ::abort(); break; @@ -1563,9 +1627,11 @@ void FiberStack::switchToFiber() { #if _WIN32 || __CYGWIN__ SwitchToFiber(osFiber); #else + __sanitizer_start_switch_fiber(&impl->originalFakeStack, impl, stackSize - sizeof(Impl)); if (_setjmp(impl->originalJmpBuf) == 0) { _longjmp(impl->fiberJmpBuf, 1); } + __sanitizer_finish_switch_fiber(impl->originalFakeStack, nullptr, nullptr); #endif #endif } @@ -1576,9 +1642,21 @@ void FiberStack::switchToMain() { #if _WIN32 || __CYGWIN__ SwitchToFiber(getMainWin32Fiber()); #else + // TODO(someady): In theory, the last time we switch away from the fiber, we should pass `nullptr` + // for the first argument here, so that ASAN destroys the fake stack. However, as currently + // designed, we don't actually know if we're switching away for the last time. It's understood + // that when we call switchToMain() in FiberStack::run(), then the main stack is allowed to + // destroy the fiber, or reuse it. I don't want to develop a mechanism to switch back to the + // fiber on final destruction just to get the hints right, so instead we leak the fake stack. + // This doesn't seem to cause any problems -- it's not even detected by ASAN as a memory leak. + // But if we wanted to run ASAN builds in production or something, it might be an issue. + __sanitizer_start_switch_fiber(&impl->fiberFakeStack, + impl->originalBottom, impl->originalSize); if (_setjmp(impl->fiberJmpBuf) == 0) { _longjmp(impl->originalJmpBuf, 1); } + __sanitizer_finish_switch_fiber(impl->fiberFakeStack, + &impl->originalBottom, &impl->originalSize); #endif #endif } @@ -1795,16 +1873,19 @@ void EventLoop::poll() { } } -void WaitScope::poll() { +uint WaitScope::poll(uint maxTurnCount) { KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); KJ_REQUIRE(!loop.running, "poll() is not allowed from within event callbacks."); loop.running = true; KJ_DEFER(loop.running = false); + uint turnCount = 0; runOnStackPool([&]() { - for (;;) { - if (!loop.turn()) { + while (turnCount < maxTurnCount) { + if (loop.turn()) { + ++turnCount; + } else { // No events in the queue. Poll for I/O. loop.poll(); @@ -1815,6 +1896,7 @@ void WaitScope::poll() { } } }); + return turnCount; } void WaitScope::cancelAllDetached() { @@ -1838,7 +1920,8 @@ static kj::CanceledException fiberCanceledException() { }; #endif -void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope) { +void waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, WaitScope& waitScope, + SourceLocation location) { EventLoop& loop = waitScope.loop; KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); @@ -1873,7 +1956,7 @@ void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope #endif KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks."); - RootEvent doneEvent(node, reinterpret_cast(&waitImpl)); + RootEvent doneEvent(node, reinterpret_cast(&waitImpl), location); node->setSelfPointer(&node); node->onReady(&doneEvent); @@ -1917,13 +2000,13 @@ void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope }); } -bool pollImpl(_::PromiseNode& node, WaitScope& waitScope) { +bool pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location) { EventLoop& loop = waitScope.loop; KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); KJ_REQUIRE(waitScope.fiber == nullptr, "poll() is not supported in fibers."); KJ_REQUIRE(!loop.running, "poll() is not allowed from within event callbacks."); - RootEvent doneEvent(&node, reinterpret_cast(&pollImpl)); + RootEvent doneEvent(&node, reinterpret_cast(&pollImpl), location); node.onReady(&doneEvent); loop.running = true; @@ -1954,20 +2037,84 @@ bool pollImpl(_::PromiseNode& node, WaitScope& waitScope) { } Promise yield() { - return _::PromiseNode::to>(kj::heap()); + class YieldPromiseNode final: public _::PromiseNode { + public: + void destroy() override {} + + void onReady(_::Event* event) noexcept override { + if (event) event->armBreadthFirst(); + } + void get(_::ExceptionOrValue& output) noexcept override { + output.as<_::Void>() = _::Void(); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(reinterpret_cast(&kj::evalLater)); + } + }; + + static YieldPromiseNode NODE; + return _::PromiseNode::to>(OwnPromiseNode(&NODE)); } Promise yieldHarder() { - return _::PromiseNode::to>(kj::heap()); + class YieldHarderPromiseNode final: public _::PromiseNode { + public: + void destroy() override {} + + void onReady(_::Event* event) noexcept override { + if (event) event->armLast(); + } + void get(_::ExceptionOrValue& output) noexcept override { + output.as<_::Void>() = _::Void(); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(reinterpret_cast(&kj::evalLast)); + } + }; + + static YieldHarderPromiseNode NODE; + return _::PromiseNode::to>(OwnPromiseNode(&NODE)); +} + +OwnPromiseNode readyNow() { + class ReadyNowPromiseNode: public ImmediatePromiseNodeBase { + // This is like `ConstPromiseNode`, but the compiler won't let me pass a literal + // value of type `Void` as a template parameter. (Might require C++20?) + + public: + void destroy() override {} + void get(ExceptionOrValue& output) noexcept override { + output.as() = Void(); + } + }; + + static ReadyNowPromiseNode NODE; + return OwnPromiseNode(&NODE); } -Own neverDone() { - return kj::heap(); +OwnPromiseNode neverDone() { + class NeverDonePromiseNode final: public _::PromiseNode { + public: + void destroy() override {} + + void onReady(_::Event* event) noexcept override { + // ignore + } + void get(_::ExceptionOrValue& output) noexcept override { + KJ_FAIL_REQUIRE("Not ready."); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(_::getMethodStartAddress(kj::NEVER_DONE, &_::NeverDone::wait)); + } + }; + + static NeverDonePromiseNode NODE; + return OwnPromiseNode(&NODE); } -void NeverDone::wait(WaitScope& waitScope) const { +void NeverDone::wait(WaitScope& waitScope, SourceLocation location) const { ExceptionOr dummy; - waitImpl(neverDone(), dummy, waitScope); + waitImpl(neverDone(), dummy, waitScope, location); KJ_UNREACHABLE; } @@ -1977,13 +2124,23 @@ void detach(kj::Promise&& promise) { loop.daemons->add(kj::mv(promise)); } -Event::Event() - : loop(currentEventLoop()), next(nullptr), prev(nullptr) {} +Event::Event(SourceLocation location) + : loop(currentEventLoop()), next(nullptr), prev(nullptr), location(location) {} -Event::Event(kj::EventLoop& loop) - : loop(loop), next(nullptr), prev(nullptr) {} +Event::Event(kj::EventLoop& loop, SourceLocation location) + : loop(loop), next(nullptr), prev(nullptr), location(location) {} Event::~Event() noexcept(false) { + live = 0; + + // Prevent compiler from eliding this store above. This line probably isn't needed because there + // are complex calls later in this destructor, and the compiler probably can't prove that they + // won't come back and examine `live`, so it won't elide the write anyway. However, an + // atomic_signal_fence is also sufficient to tell the compiler that a signal handler might access + // `live`, so it won't optimize away the write. Note that a signal fence does not produce + // any instructions, it just blocks compiler optimizations. + std::atomic_signal_fence(std::memory_order_acq_rel); + disarm(); KJ_REQUIRE(!firing, "Promise callback destroyed itself."); @@ -1993,6 +2150,11 @@ void Event::armDepthFirst() { KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, "Event armed from different thread than it was created in. You must use " "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } if (prev == nullptr) { next = *loop.depthFirstInsertPoint; @@ -2019,6 +2181,11 @@ void Event::armBreadthFirst() { KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, "Event armed from different thread than it was created in. You must use " "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } if (prev == nullptr) { next = *loop.breadthFirstInsertPoint; @@ -2042,6 +2209,11 @@ void Event::armLast() { KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, "Event armed from different thread than it was created in. You must use " "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } if (prev == nullptr) { next = *loop.breadthFirstInsertPoint; @@ -2062,6 +2234,10 @@ void Event::armLast() { } } +bool Event::isNext() { + return loop.running && loop.head == this; +} + void Event::disarm() { if (prev != nullptr) { if (threadLocalEventLoop != &loop && threadLocalEventLoop != nullptr) { @@ -2132,7 +2308,7 @@ kj::String PromiseBase::trace() { return kj::str(builder); } -void PromiseNode::setSelfPointer(Own* selfPtr) noexcept {} +void PromiseNode::setSelfPointer(OwnPromiseNode* selfPtr) noexcept {} void PromiseNode::OnReadyEvent::init(Event* newEvent) { if (event == _kJ_ALREADY_READY) { @@ -2187,13 +2363,15 @@ void ImmediatePromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNe ImmediateBrokenPromiseNode::ImmediateBrokenPromiseNode(Exception&& exception) : exception(kj::mv(exception)) {} +void ImmediateBrokenPromiseNode::destroy() { freePromise(this); } + void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept { output.exception = kj::mv(exception); } // ------------------------------------------------------------------- -AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own&& dependencyParam) +AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(OwnPromiseNode&& dependencyParam) : dependency(kj::mv(dependencyParam)) { dependency->setSelfPointer(&dependency); } @@ -2220,7 +2398,7 @@ void AttachmentPromiseNodeBase::dropDependency() { // ------------------------------------------------------------------- TransformPromiseNodeBase::TransformPromiseNodeBase( - Own&& dependencyParam, void* continuationTracePtr) + OwnPromiseNode&& dependencyParam, void* continuationTracePtr) : dependency(kj::mv(dependencyParam)), continuationTracePtr(continuationTracePtr) { dependency->setSelfPointer(&dependency); } @@ -2268,7 +2446,7 @@ void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) { // ------------------------------------------------------------------- -ForkBranchBase::ForkBranchBase(Own&& hubParam): hub(kj::mv(hubParam)) { +ForkBranchBase::ForkBranchBase(OwnForkHubBase&& hubParam): hub(kj::mv(hubParam)) { if (hub->tailBranch == nullptr) { onReadyEvent.arm(); } else { @@ -2317,8 +2495,9 @@ void ForkBranchBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { // ------------------------------------------------------------------- -ForkHubBase::ForkHubBase(Own&& innerParam, ExceptionOrValue& resultRef) - : inner(kj::mv(innerParam)), resultRef(resultRef) { +ForkHubBase::ForkHubBase(OwnPromiseNode&& innerParam, ExceptionOrValue& resultRef, + SourceLocation location) + : Event(location), inner(kj::mv(innerParam)), resultRef(resultRef) { inner->setSelfPointer(&inner); inner->onReady(this); } @@ -2358,14 +2537,16 @@ void ForkHubBase::traceEvent(TraceBuilder& builder) { // ------------------------------------------------------------------- -ChainPromiseNode::ChainPromiseNode(Own innerParam) - : state(STEP1), inner(kj::mv(innerParam)) { +ChainPromiseNode::ChainPromiseNode(OwnPromiseNode innerParam, SourceLocation location) + : Event(location), state(STEP1), inner(kj::mv(innerParam)) { inner->setSelfPointer(&inner); inner->onReady(this); } ChainPromiseNode::~ChainPromiseNode() noexcept(false) {} +void ChainPromiseNode::destroy() { freePromise(this); } + void ChainPromiseNode::onReady(Event* event) noexcept { switch (state) { case STEP1: @@ -2378,7 +2559,7 @@ void ChainPromiseNode::onReady(Event* event) noexcept { KJ_UNREACHABLE; } -void ChainPromiseNode::setSelfPointer(Own* selfPtr) noexcept { +void ChainPromiseNode::setSelfPointer(OwnPromiseNode* selfPtr) noexcept { if (state == STEP2) { *selfPtr = kj::mv(inner); // deletes this! selfPtr->get()->setSelfPointer(selfPtr); @@ -2422,7 +2603,7 @@ Maybe> ChainPromiseNode::fire() { // There is an exception. If there is also a value, delete it. kj::runCatchingExceptions([&]() { intermediate.value = nullptr; }); // Now set step2 to a rejected promise. - inner = heap(kj::mv(*exception)); + inner = allocPromise(kj::mv(*exception)); } else KJ_IF_MAYBE(value, intermediate.value) { // There is a value and no exception. The value is itself a promise. Adopt it as our // step2. @@ -2444,7 +2625,7 @@ Maybe> ChainPromiseNode::fire() { } // Return our self-pointer so that the caller takes care of deleting it. - return Own(kj::mv(chain)); + return Own(kj::Own(kj::mv(chain))); } else { inner->setSelfPointer(&inner); if (onReadyEvent != nullptr) { @@ -2476,11 +2657,14 @@ void ChainPromiseNode::traceEvent(TraceBuilder& builder) { // ------------------------------------------------------------------- -ExclusiveJoinPromiseNode::ExclusiveJoinPromiseNode(Own left, Own right) - : left(*this, kj::mv(left)), right(*this, kj::mv(right)) {} +ExclusiveJoinPromiseNode::ExclusiveJoinPromiseNode( + OwnPromiseNode left, OwnPromiseNode right, SourceLocation location) + : left(*this, kj::mv(left), location), right(*this, kj::mv(right), location) {} ExclusiveJoinPromiseNode::~ExclusiveJoinPromiseNode() noexcept(false) {} +void ExclusiveJoinPromiseNode::destroy() { freePromise(this); } + void ExclusiveJoinPromiseNode::onReady(Event* event) noexcept { onReadyEvent.init(event); } @@ -2504,8 +2688,8 @@ void ExclusiveJoinPromiseNode::tracePromise(TraceBuilder& builder, bool stopAtNe } ExclusiveJoinPromiseNode::Branch::Branch( - ExclusiveJoinPromiseNode& joinNode, Own dependencyParam) - : joinNode(joinNode), dependency(kj::mv(dependencyParam)) { + ExclusiveJoinPromiseNode& joinNode, OwnPromiseNode dependencyParam, SourceLocation location) + : Event(location), joinNode(joinNode), dependency(kj::mv(dependencyParam)) { dependency->setSelfPointer(&dependency); dependency->onReady(this); } @@ -2548,14 +2732,15 @@ void ExclusiveJoinPromiseNode::Branch::traceEvent(TraceBuilder& builder) { // ------------------------------------------------------------------- ArrayJoinPromiseNodeBase::ArrayJoinPromiseNodeBase( - Array> promises, ExceptionOrValue* resultParts, size_t partSize) - : countLeft(promises.size()) { + Array promises, ExceptionOrValue* resultParts, size_t partSize, + SourceLocation location, ArrayJoinBehavior joinBehavior) + : joinBehavior(joinBehavior), countLeft(promises.size()) { // Make the branches. auto builder = heapArrayBuilder(promises.size()); for (uint i: indices(promises)) { ExceptionOrValue& output = *reinterpret_cast( reinterpret_cast(resultParts) + i * partSize); - builder.add(*this, kj::mv(promises[i]), output); + builder.add(*this, kj::mv(promises[i]), output, location); } branches = builder.finish(); @@ -2570,13 +2755,21 @@ void ArrayJoinPromiseNodeBase::onReady(Event* event) noexcept { } void ArrayJoinPromiseNodeBase::get(ExceptionOrValue& output) noexcept { - // If any of the elements threw exceptions, propagate them. for (auto& branch: branches) { - KJ_IF_MAYBE(exception, branch.getPart()) { + if (joinBehavior == ArrayJoinBehavior::LAZY) { + // This implements `joinPromises()`'s lazy evaluation semantics. + branch.dependency->get(branch.output); + } + + // If any of the elements threw exceptions, propagate them. + KJ_IF_MAYBE(exception, branch.output.exception) { output.addException(kj::mv(*exception)); } } + // We either failed fast, or waited for all promises. + KJ_DASSERT(countLeft == 0 || output.exception != nullptr); + if (output.exception == nullptr) { // No errors. The template subclass will need to fill in the result. getNoError(output); @@ -2596,8 +2789,9 @@ void ArrayJoinPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNe } ArrayJoinPromiseNodeBase::Branch::Branch( - ArrayJoinPromiseNodeBase& joinNode, Own dependencyParam, ExceptionOrValue& output) - : joinNode(joinNode), dependency(kj::mv(dependencyParam)), output(output) { + ArrayJoinPromiseNodeBase& joinNode, OwnPromiseNode dependencyParam, ExceptionOrValue& output, + SourceLocation location) + : Event(location), joinNode(joinNode), dependency(kj::mv(dependencyParam)), output(output) { dependency->setSelfPointer(&dependency); dependency->onReady(this); } @@ -2605,9 +2799,20 @@ ArrayJoinPromiseNodeBase::Branch::Branch( ArrayJoinPromiseNodeBase::Branch::~Branch() noexcept(false) {} Maybe> ArrayJoinPromiseNodeBase::Branch::fire() { - if (--joinNode.countLeft == 0) { + if (--joinNode.countLeft == 0 && !joinNode.armed) { joinNode.onReadyEvent.arm(); + joinNode.armed = true; + } + + if (joinNode.joinBehavior == ArrayJoinBehavior::EAGER) { + // This implements `joinPromisesFailFast()`'s eager-evaluation semantics. + dependency->get(output); + if (output.exception != nullptr && !joinNode.armed) { + joinNode.onReadyEvent.arm(); + joinNode.armed = true; + } } + return nullptr; } @@ -2616,28 +2821,35 @@ void ArrayJoinPromiseNodeBase::Branch::traceEvent(TraceBuilder& builder) { joinNode.onReadyEvent.traceEvent(builder); } -Maybe ArrayJoinPromiseNodeBase::Branch::getPart() { - dependency->get(output); - return kj::mv(output.exception); -} - ArrayJoinPromiseNode::ArrayJoinPromiseNode( - Array> promises, Array> resultParts) - : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr<_::Void>)), + Array promises, Array> resultParts, + SourceLocation location, ArrayJoinBehavior joinBehavior) + : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr<_::Void>), + location, joinBehavior), resultParts(kj::mv(resultParts)) {} ArrayJoinPromiseNode::~ArrayJoinPromiseNode() {} +void ArrayJoinPromiseNode::destroy() { freePromise(this); } + void ArrayJoinPromiseNode::getNoError(ExceptionOrValue& output) noexcept { output.as<_::Void>() = _::Void(); } } // namespace _ (private) -Promise joinPromises(Array>&& promises) { - return _::PromiseNode::to>(kj::heap<_::ArrayJoinPromiseNode>( +Promise joinPromises(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>(_::allocPromise<_::ArrayJoinPromiseNode>( KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, - heapArray<_::ExceptionOr<_::Void>>(promises.size()))); + heapArray<_::ExceptionOr<_::Void>>(promises.size()), location, + _::ArrayJoinBehavior::LAZY)); +} + +Promise joinPromisesFailFast(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr<_::Void>>(promises.size()), location, + _::ArrayJoinBehavior::EAGER)); } namespace _ { // (private) @@ -2645,8 +2857,8 @@ namespace _ { // (private) // ------------------------------------------------------------------- EagerPromiseNodeBase::EagerPromiseNodeBase( - Own&& dependencyParam, ExceptionOrValue& resultRef) - : dependency(kj::mv(dependencyParam)), resultRef(resultRef) { + OwnPromiseNode&& dependencyParam, ExceptionOrValue& resultRef, SourceLocation location) + : Event(location), dependency(kj::mv(dependencyParam)), resultRef(resultRef) { dependency->setSelfPointer(&dependency); dependency->onReady(this); } @@ -2727,4 +2939,218 @@ namespace _ { // (private) Promise IdentityFunc>::operator()() const { return READY_NOW; } } // namespace _ (private) + +// ------------------------------------------------------------------- + +#if KJ_HAS_COROUTINE + +namespace _ { // (private) + +CoroutineBase::CoroutineBase(stdcoro::coroutine_handle<> coroutine, ExceptionOrValue& resultRef, + SourceLocation location) + : Event(location), + coroutine(coroutine), + resultRef(resultRef) {} +CoroutineBase::~CoroutineBase() noexcept(false) { + readMaybe(maybeDisposalResults)->destructorRan = true; +} + +void CoroutineBase::unhandled_exception() { + // Pretty self-explanatory, we propagate the exception to the promise which owns us, unless + // we're being destroyed, in which case we propagate it back to our disposer. Note that all + // unhandled exceptions end up here, not just ones after the first co_await. + + auto exception = getCaughtExceptionAsKj(); + + KJ_IF_MAYBE(disposalResults, maybeDisposalResults) { + // Exception during coroutine destruction. Only record the first one. + if (disposalResults->exception == nullptr) { + disposalResults->exception = kj::mv(exception); + } + } else if (isWaiting()) { + // Exception during coroutine execution. + resultRef.addException(kj::mv(exception)); + scheduleResumption(); + } else { + // Okay, what could this mean? We've already been fulfilled or rejected, but we aren't being + // destroyed yet. The only possibility is that we are unwinding the coroutine frame due to a + // successful completion, and something in the frame threw. We can't already be rejected, + // because rejecting a coroutine involves throwing, which would have unwound the frame prior + // to setting `waiting = false`. + // + // Since we know we're unwinding due to a successful completion, we also know that whatever + // Event we may have armed has not yet fired, because we haven't had a chance to return to + // the event loop. + + // final_suspend() has not been called. +#if _MSC_VER && !defined(__clang__) + // See comment at `finalSuspendCalled`'s definition. + KJ_IASSERT(!finalSuspendCalled); +#else + KJ_IASSERT(!coroutine.done()); +#endif + + // Since final_suspend() hasn't been called, whatever Event is waiting on us has not fired, + // and will see this exception. + resultRef.addException(kj::mv(exception)); + } +} + +void CoroutineBase::onReady(Event* event) noexcept { + onReadyEvent.init(event); +} + +void CoroutineBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + if (stopAtNextEvent) return; + + KJ_IF_MAYBE(promise, promiseNodeForTrace) { + promise->tracePromise(builder, stopAtNextEvent); + } + + // Maybe returning the address of coroutine() will give us a function name with meaningful type + // information. (Narrator: It doesn't.) + builder.add(GetFunctorStartAddress<>::apply(coroutine)); +}; + +Maybe> CoroutineBase::fire() { + // Call Awaiter::await_resume() and proceed with the coroutine. Note that this will not destroy + // the coroutine if control flows off the end of it, because we return suspend_always() from + // final_suspend(). + // + // It's tempting to arrange to check for exceptions right now and reject the promise that owns + // us without resuming the coroutine, which would save us from throwing an exception when we + // already know where it's going. But, we don't really know: unlike in the KJ_NO_EXCEPTIONS + // case, the `co_await` might be in a try-catch block, so we have no choice but to resume and + // throw later. + // + // TODO(someday): If we ever support coroutines with -fno-exceptions, we'll need to reject the + // enclosing coroutine promise here, if the Awaiter's result is exceptional. + + promiseNodeForTrace = nullptr; + + coroutine.resume(); + + return nullptr; +} + +void CoroutineBase::traceEvent(TraceBuilder& builder) { + KJ_IF_MAYBE(promise, promiseNodeForTrace) { + promise->tracePromise(builder, true); + } + + // Maybe returning the address of coroutine() will give us a function name with meaningful type + // information. (Narrator: It doesn't.) + builder.add(GetFunctorStartAddress<>::apply(coroutine)); + + onReadyEvent.traceEvent(builder); +} + +void CoroutineBase::destroy() { + // Called by PromiseDisposer to delete the object. Basically a wrapper around coroutine.destroy() + // with some stuff to propagate exceptions appropriately. + + // Objects in the coroutine frame might throw from their destructors, so unhandled_exception() + // will need some way to communicate those exceptions back to us. Separately, we also want + // confirmation that our own ~Coroutine() destructor ran. To solve this, we put a + // DisposalResults object on the stack and set a pointer to it in the Coroutine object. This + // indicates to unhandled_exception() and ~Coroutine() where to store the results of the + // destruction operation. + DisposalResults disposalResults; + maybeDisposalResults = &disposalResults; + + // Need to save this while `unwindDetector` is still valid. + bool shouldRethrow = !unwindDetector.isUnwinding(); + + do { + // Clang's implementation of the Coroutines TS does not destroy the Coroutine object or + // deallocate the coroutine frame if a destructor of an object on the frame threw an + // exception. This is despite the fact that it delivered the exception to _us_ via + // unhandled_exception(). Anyway, it appears we can work around this by running + // coroutine.destroy() a second time. + // + // On Clang, `disposalResults.exception != nullptr` implies `!disposalResults.destructorRan`. + // We could optimize out the separate `destructorRan` flag if we verify that other compilers + // behave the same way. + coroutine.destroy(); + } while (!disposalResults.destructorRan); + + // WARNING: `this` is now a dangling pointer. + + KJ_IF_MAYBE(exception, disposalResults.exception) { + if (shouldRethrow) { + kj::throwFatalException(kj::mv(*exception)); + } else { + // An exception is already unwinding the stack, so throwing this secondary exception would + // call std::terminate(). + } + } +} + +CoroutineBase::AwaiterBase::AwaiterBase(OwnPromiseNode node): node(kj::mv(node)) {} +CoroutineBase::AwaiterBase::AwaiterBase(AwaiterBase&&) = default; +CoroutineBase::AwaiterBase::~AwaiterBase() noexcept(false) { + // Make sure it's safe to generate an async stack trace between now and when the Coroutine is + // destroyed. + KJ_IF_MAYBE(coroutineEvent, maybeCoroutineEvent) { + coroutineEvent->promiseNodeForTrace = nullptr; + } + + unwindDetector.catchExceptionsIfUnwinding([this]() { + // No need to check for a moved-from state, node will just ignore the nullification. + node = nullptr; + }); +} + +void CoroutineBase::AwaiterBase::getImpl(ExceptionOrValue& result, void* awaitedAt) { + node->get(result); + + KJ_IF_MAYBE(exception, result.exception) { + // Manually extend the stack trace with the instruction address where the co_await occurred. + exception->addTrace(awaitedAt); + + // Pass kj::maxValue for ignoreCount here so that `throwFatalException()` dosen't try to + // extend the stack trace. There's no point in extending the trace beyond the single frame we + // added above, as the rest of the trace will always be async framework stuff that no one wants + // to see. + kj::throwFatalException(kj::mv(*exception), kj::maxValue); + } +} + +bool CoroutineBase::AwaiterBase::awaitSuspendImpl(CoroutineBase& coroutineEvent) { + node->setSelfPointer(&node); + node->onReady(&coroutineEvent); + + if (coroutineEvent.hasSuspendedAtLeastOnce && coroutineEvent.isNext()) { + // The result is immediately ready and this coroutine is running on the event loop's stack, not + // a user code stack. Let's cancel our event and immediately resume. It's important that we + // don't perform this optimization if this is the first suspension, because our caller may + // depend on running code before this promise's continuations fire. + coroutineEvent.disarm(); + + // We can resume ourselves by returning false. This accomplishes the same thing as if we had + // returned true from await_ready(). + return false; + } else { + // Otherwise, we must suspend. Store a reference to the promise we're waiting on for tracing + // purposes; coroutineEvent.fire() and/or ~Adapter() will null this out. + coroutineEvent.promiseNodeForTrace = *node; + maybeCoroutineEvent = coroutineEvent; + + coroutineEvent.hasSuspendedAtLeastOnce = true; + + return true; + } +} + +// --------------------------------------------------------- +// Helpers for coCapture() + +void throwMultipleCoCaptureInvocations() { + KJ_FAIL_REQUIRE("Attempted to invoke CaptureForCoroutine functor multiple times"); +} + +} // namespace _ (private) + +#endif // KJ_HAS_COROUTINE + } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/async.h b/libs/EXTERNAL/capnproto/c++/src/kj/async.h index d6e503a724f..564b5171549 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/async.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/async.h @@ -22,8 +22,8 @@ #pragma once #include "async-prelude.h" -#include "exception.h" -#include "refcount.h" +#include +#include KJ_BEGIN_HEADER @@ -63,6 +63,57 @@ using PromiseForResult = _::ReducePromises<_::ReturnType>; // T. If T is void, then the promise is for the result of calling Func with no arguments. If // Func itself returns a promise, the promises are joined, so you never get Promise>. +// ======================================================================================= + +class AsyncObject { + // You may optionally inherit privately from this to indicate that the type is a KJ async object, + // meaning it deals with KJ async I/O making it tied to a specific thread and event loop. This + // enables some additional debug checks, but does not otherwise have any effect on behavior as + // long as there are no bugs. + // + // (We prefer inheritance rather than composition here because inheriting an empty type adds zero + // size to the derived class.) + +public: + ~AsyncObject(); + +private: + KJ_NORETURN(static void failed() noexcept); +}; + +class DisallowAsyncDestructorsScope { + // Create this type on the stack in order to specify that during its scope, no KJ async objects + // should be destroyed. If AsyncObject's destructor is called in this scope, the process will + // crash with std::terminate(). + // + // This is useful as a sort of "sanitizer" to catch bugs. When tearing down an object that is + // intended to be passed between threads, you can set up one of these scopes to catch whether + // the object contains any async objects, which are not legal to pass across threads. + +public: + explicit DisallowAsyncDestructorsScope(kj::StringPtr reason); + ~DisallowAsyncDestructorsScope(); + KJ_DISALLOW_COPY_AND_MOVE(DisallowAsyncDestructorsScope); + +private: + kj::StringPtr reason; + DisallowAsyncDestructorsScope* previousValue; + + friend class AsyncObject; +}; + +class AllowAsyncDestructorsScope { + // Negates the effect of DisallowAsyncDestructorsScope. + +public: + AllowAsyncDestructorsScope(); + ~AllowAsyncDestructorsScope(); + KJ_DISALLOW_COPY_AND_MOVE(AllowAsyncDestructorsScope); + +private: + DisallowAsyncDestructorsScope* previousValue; +}; + // ======================================================================================= // Promises @@ -150,8 +201,8 @@ class Promise: protected _::PromiseBase { inline Promise(decltype(nullptr)) {} template - PromiseForResult then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException()) - KJ_WARN_UNUSED_RESULT; + PromiseForResult then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException(), + SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Register a continuation function to be executed when the promise completes. The continuation // (`func`) takes the promised value (an rvalue of type `T`) as its parameter. The continuation // may return a new value; `then()` itself returns a promise for the continuation's eventual @@ -212,11 +263,11 @@ class Promise: protected _::PromiseBase { // You must still wait on the returned promise if you want the task to execute. template - Promise catch_(ErrorFunc&& errorHandler) KJ_WARN_UNUSED_RESULT; + Promise catch_(ErrorFunc&& errorHandler, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Equivalent to `.then(identityFunc, errorHandler)`, where `identifyFunc` is a function that // just returns its input. - T wait(WaitScope& waitScope); + T wait(WaitScope& waitScope, SourceLocation location = {}); // Run the event loop until the promise is fulfilled, then return its result. If the promise // is rejected, throw an exception. // @@ -256,7 +307,7 @@ class Promise: protected _::PromiseBase { // switches back to the main stack in order to run the event loop, returning to the fiber's stack // once the awaited promise resolves. - bool poll(WaitScope& waitScope); + bool poll(WaitScope& waitScope, SourceLocation location = {}); // Returns true if a call to wait() would complete without blocking, false if it would block. // // If the promise is not yet resolved, poll() will pump the event loop and poll for I/O in an @@ -271,19 +322,19 @@ class Promise: protected _::PromiseBase { // // poll() is not supported in fibers; it will throw an exception. - ForkedPromise fork() KJ_WARN_UNUSED_RESULT; + ForkedPromise fork(SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Forks the promise, so that multiple different clients can independently wait on the result. // `T` must be copy-constructable for this to work. Or, in the special case where `T` is // `Own`, `U` must have a method `Own addRef()` which returns a new reference to the same // (or an equivalent) object (probably implemented via reference counting). - _::SplitTuplePromise split(); + _::SplitTuplePromise split(SourceLocation location = {}); // Split a promise for a tuple into a tuple of promises. // // E.g. if you have `Promise>`, `split()` returns // `kj::Tuple, Promise>`. - Promise exclusiveJoin(Promise&& other) KJ_WARN_UNUSED_RESULT; + Promise exclusiveJoin(Promise&& other, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Return a new promise that resolves when either the original promise resolves or `other` // resolves (whichever comes first). The promise that didn't resolve first is canceled. @@ -298,8 +349,9 @@ class Promise: protected _::PromiseBase { // runs -- after calling then(), use attach() to add necessary objects to the result. template - Promise eagerlyEvaluate(ErrorFunc&& errorHandler) KJ_WARN_UNUSED_RESULT; - Promise eagerlyEvaluate(decltype(nullptr)) KJ_WARN_UNUSED_RESULT; + Promise eagerlyEvaluate(ErrorFunc&& errorHandler, SourceLocation location = {}) + KJ_WARN_UNUSED_RESULT; + Promise eagerlyEvaluate(decltype(nullptr), SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Force eager evaluation of this promise. Use this if you are going to hold on to the promise // for awhile without consuming the result, but you want to make sure that the system actually // processes it. @@ -328,7 +380,7 @@ class Promise: protected _::PromiseBase { // This method does NOT consume the promise as other methods do. private: - Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {} + Promise(bool, _::OwnPromiseNode&& node): PromiseBase(kj::mv(node)) {} // Second parameter prevent ambiguity with immediate-value constructor. friend class _::PromiseNode; @@ -357,7 +409,7 @@ class ForkedPromise { friend class EventLoop; }; -constexpr _::Void READY_NOW = _::Void(); +constexpr _::ReadyNow READY_NOW = _::ReadyNow(); // Use this when you need a Promise that is already fulfilled -- this value can be implicitly // cast to `Promise`. @@ -366,6 +418,11 @@ constexpr _::NeverDone NEVER_DONE = _::NeverDone(); // implicitly converted to any promise type. You may also call `NEVER_DONE.wait()` to wait // forever (useful for servers). +template +Promise constPromise(); +// Construct a Promise which resolves to the given constant value. This function is equivalent to +// `Promise(value)` except that it avoids an allocation. + template PromiseForResult evalLater(Func&& func) KJ_WARN_UNUSED_RESULT; // Schedule for the given zero-parameter function to be executed in the event loop at some @@ -424,7 +481,8 @@ PromiseForResult retryOnDisconnect(Func&& func) KJ_WARN_UNUSED_RESUL // with the retry logic added. template -PromiseForResult startFiber(size_t stackSize, Func&& func) KJ_WARN_UNUSED_RESULT; +PromiseForResult startFiber( + size_t stackSize, Func&& func, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Executes `func()` in a fiber, returning a promise for the eventual reseult. `func()` will be // passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on promises. Thus, `func()` // can be written in a synchronous, blocking style, instead of using `.then()`. This is often much @@ -450,7 +508,7 @@ class FiberPool final { public: explicit FiberPool(size_t stackSize); ~FiberPool() noexcept(false); - KJ_DISALLOW_COPY(FiberPool); + KJ_DISALLOW_COPY_AND_MOVE(FiberPool); void setMaxFreelist(size_t count); // Set the maximum number of stacks to add to the freelist. If the freelist is full, stacks will @@ -463,7 +521,8 @@ class FiberPool final { // feature is only supported on Linux (the flag has no effect on other operating systems). template - PromiseForResult startFiber(Func&& func) const KJ_WARN_UNUSED_RESULT; + PromiseForResult startFiber( + Func&& func, SourceLocation location = {}) const KJ_WARN_UNUSED_RESULT; // Executes `func()` in a fiber from this pool, returning a promise for the eventual result. // `func()` will be passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on // promises. Thus, `func()` can be written in a synchronous, blocking style, instead of @@ -496,8 +555,19 @@ class FiberPool final { }; template -Promise> joinPromises(Array>&& promises); -// Join an array of promises into a promise for an array. +Promise> joinPromises(Array>&& promises, SourceLocation location = {}); +// Join an array of promises into a promise for an array. Trailing continuations on promises are not +// evaluated until all promises have settled. Exceptions are propagated only after the last promise +// has settled. +// +// TODO(cleanup): It is likely that `joinPromisesFailFast()` is what everyone should be using. +// Deprecate this function. + +template +Promise> joinPromisesFailFast(Array>&& promises, SourceLocation location = {}); +// Join an array of promises into a promise for an array. Trailing continuations on promises are +// evaluated eagerly. If any promise results in an exception, the exception is immediately +// propagated to the returned join promise. // ======================================================================================= // Hack for creating a lambda that holds an owned pointer. @@ -519,6 +589,10 @@ class CaptureByMove { MovedParam param; }; +template +inline CaptureByMove> mvCapture(MovedParam&& param, Func&& func) + KJ_DEPRECATED("Use C++14 generalized captures instead."); + template inline CaptureByMove> mvCapture(MovedParam&& param, Func&& func) { // Hack to create a "lambda" which captures a variable by moving it rather than copying or @@ -534,10 +608,121 @@ inline CaptureByMove> mvCapture(MovedParam&& param, Func return CaptureByMove>(kj::fwd(func), kj::mv(param)); } +// ======================================================================================= +// Hack for safely using a lambda as a coroutine. + +#if KJ_HAS_COROUTINE + +namespace _ { + +void throwMultipleCoCaptureInvocations(); + +template +struct CaptureForCoroutine { + kj::Maybe maybeFunctor; + + explicit CaptureForCoroutine(Functor&& f) : maybeFunctor(kj::mv(f)) {} + + template + static auto coInvoke(Functor functor, Args&&... args) + -> decltype(functor(kj::fwd(args)...)) { + // Since the functor is now in the local scope and no longer a member variable, it will be + // persisted in the coroutine state. + + // Note that `co_await functor(...)` can still return `void`. It just happens that + // `co_return voidReturn();` is explicitly allowed. + co_return co_await functor(kj::fwd(args)...); + } + + template + auto operator()(Args&&... args) { + if (maybeFunctor == nullptr) { + throwMultipleCoCaptureInvocations(); + } + auto localFunctor = kj::mv(*kj::_::readMaybe(maybeFunctor)); + maybeFunctor = nullptr; + return coInvoke(kj::mv(localFunctor), kj::fwd(args)...); + } +}; + +} // namespace _ + +template +auto coCapture(Functor&& f) { + // Assuming `f()` returns a Promise `p`, wrap `f` in such a way that it will outlive its + // returned Promise. Note that the returned object may only be invoked once. + // + // This function is meant to help address this pain point with functors that return a coroutine: + // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rcoro-capture + // + // The two most common patterns where this may be useful look like so: + // ``` + // void addTask(Value myValue) { + // auto myFun = [myValue]() -> kj::Promise { + // ... + // co_return; + // }; + // tasks.add(myFun()); + // } + // ``` + // and + // ``` + // kj::Promise afterPromise(kj::Promise promise, Value myValue) { + // auto myFun = [myValue]() -> kj::Promise { + // ... + // co_return; + // }; + // return promise.then(kj::mv(myFun)); + // } + // ``` + // + // Note that there are potentially more optimal alternatives to both of these patterns: + // ``` + // void addTask(Value myValue) { + // auto myFun = [](auto myValue) -> kj::Promise { + // ... + // co_return; + // }; + // tasks.add(myFun(myValue)); + // } + // ``` + // and + // ``` + // kj::Promise afterPromise(kj::Promise promise, Value myValue) { + // auto myFun = [&]() -> kj::Promise { + // ... + // co_return; + // }; + // co_await promise; + // co_await myFun(); + // co_return; + // } + // ``` + // + // For situations where you are trying to capture a specific local variable, kj::mvCapture() can + // also be useful: + // ``` + // kj::Promise reactToPromise(kj::Promise promise) { + // BigA a; + // TinyB b; + // + // doSomething(a, b); + // return promise.then(kj::mvCapture(b, [](TinyB b, MyType type) -> kj::Promise { + // ... + // co_return; + // }); + // } + // ``` + + return _::CaptureForCoroutine(kj::mv(f)); +} + +#endif // KJ_HAS_COROUTINE + // ======================================================================================= // Advanced promise construction -class PromiseRejector { +class PromiseRejector: private AsyncObject { // Superclass of PromiseFulfiller containing the non-typed methods. Useful when you only really // need to be able to reject a promise, and you need to operate on fulfillers of different types. public: @@ -613,7 +798,7 @@ struct PromiseFulfillerPair { }; template -PromiseFulfillerPair newPromiseAndFulfiller(); +PromiseFulfillerPair newPromiseAndFulfiller(SourceLocation location = {}); // Construct a Promise and a separate PromiseFulfiller which can be used to fulfill the promise. // If the PromiseFulfiller is destroyed before either of its methods are called, the Promise is // implicitly rejected. @@ -678,7 +863,7 @@ PromiseCrossThreadFulfillerPair newPromiseAndCrossThreadFulfiller(); // ======================================================================================= // Canceler -class Canceler { +class Canceler: private AsyncObject { // A Canceler can wrap some set of Promises and then forcefully cancel them on-demand, or // implicitly when the Canceler is destroyed. // @@ -692,7 +877,7 @@ class Canceler { // Canceler and using it to wrap promises before returning them to callers. When Bob is // destroyed, the Canceler is destroyed too, and all promises Bob wrapped with it throw errors. // - // Note that another common strategy for cancelation is to use exclusiveJoin() to join a promise + // Note that another common strategy for cancellation is to use exclusiveJoin() to join a promise // with some "cancellation promise" which only resolves if the operation should be canceled. The // cancellation promise could itself be created by newPromiseAndFulfiller(), and thus // calling the PromiseFulfiller cancels the operation. There is a major problem with this @@ -704,7 +889,7 @@ class Canceler { public: inline Canceler() {} ~Canceler() noexcept(false); - KJ_DISALLOW_COPY(Canceler); + KJ_DISALLOW_COPY_AND_MOVE(Canceler); template Promise wrap(Promise promise) { @@ -783,7 +968,7 @@ class Canceler::AdapterImpl: public AdapterBase { // ======================================================================================= // TaskSet -class TaskSet { +class TaskSet: private AsyncObject { // Holds a collection of Promises and ensures that each executes to completion. Memory // associated with each promise is automatically freed when the promise completes. Destroying // the TaskSet itself automatically cancels all unfinished promises. @@ -791,7 +976,7 @@ class TaskSet { // This is useful for "daemon" objects that perform background tasks which aren't intended to // fulfill any particular external promise, but which may need to be canceled (and thus can't // use `Promise::detach()`). The daemon object holds a TaskSet to collect these tasks it is - // working on. This way, if the daemon itself is destroyed, the TaskSet is detroyed as well, + // working on. This way, if the daemon itself is destroyed, the TaskSet is destroyed as well, // and everything the daemon is doing is canceled. public: @@ -800,7 +985,7 @@ class TaskSet { virtual void taskFailed(kj::Exception&& exception) = 0; }; - TaskSet(ErrorHandler& errorHandler); + TaskSet(ErrorHandler& errorHandler, SourceLocation location = {}); // `errorHandler` will be executed any time a task throws an exception, and will execute within // the given EventLoop. @@ -818,12 +1003,23 @@ class TaskSet { // Returns a promise that fulfills the next time the TaskSet is empty. Only one such promise can // exist at a time. + void clear(); + // Cancel all tasks. + // + // As always, it is not safe to cancel the task that is currently running, so you could not call + // this from inside a task in the TaskSet. However, it IS safe to call this from the + // `taskFailed()` callback. + // + // Calling this will always trigger onEmpty(), if anyone is listening. + private: class Task; + using OwnTask = Own; TaskSet::ErrorHandler& errorHandler; - Maybe> tasks; + Maybe tasks; Maybe>> emptyFulfiller; + SourceLocation location; }; // ======================================================================================= @@ -865,7 +1061,7 @@ class Executor { // for "try" versions... template - PromiseForResult executeAsync(Func&& func) const; + PromiseForResult executeAsync(Func&& func, SourceLocation location = {}) const; // Call from any thread to request that the given function be executed on the executor's thread, // returning a promise for the result. // @@ -907,7 +1103,8 @@ class Executor { // call provides E-Order in the same way as Cap'n Proto.) template - _::UnwrapPromise> executeSync(Func&& func) const; + _::UnwrapPromise> executeSync( + Func&& func, SourceLocation location = {}) const; // Schedules `func()` to execute on the executor thread, and then blocks the requesting thread // until `func()` completes. If `func()` returns a Promise, then the wait will continue until // that promise resolves, and the final result will be returned to the requesting thread. @@ -1084,9 +1281,9 @@ class EventLoop { void poll(); friend void _::detach(kj::Promise&& promise); - friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, - WaitScope& waitScope); - friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope); + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); friend class _::Event; friend class WaitScope; friend class Executor; @@ -1109,10 +1306,12 @@ class WaitScope { public: inline explicit WaitScope(EventLoop& loop): loop(loop) { loop.enterScope(); } inline ~WaitScope() { if (fiber == nullptr) loop.leaveScope(); } - KJ_DISALLOW_COPY(WaitScope); + KJ_DISALLOW_COPY_AND_MOVE(WaitScope); - void poll(); - // Pumps the event queue and polls for I/O until there's nothing left to do (without blocking). + uint poll(uint maxTurnCount = maxValue); + // Pumps the event queue and polls for I/O until there's nothing left to do (without blocking) or + // the maximum turn count has been reached. Returns the number of events popped off the event + // queue. // // Not supported in fibers. @@ -1172,9 +1371,9 @@ class WaitScope { friend class EventLoop; friend class _::FiberBase; - friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, - WaitScope& waitScope); - friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope); + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); }; } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/cidr.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/cidr.c++ new file mode 100644 index 00000000000..6a1fa32e40e --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/cidr.c++ @@ -0,0 +1,179 @@ +// Copyright (c) 2013-2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 +// Request Vista-level APIs. +#include +#endif + +#include "debug.h" +#include "cidr.h" + +#if _WIN32 +#include +#include +#include +#include +#define inet_pton InetPtonA +#define inet_ntop InetNtopA +#include +#define dup _dup +#else +#include +#include +#endif + +#if __FreeBSD__ +#include +#endif + +namespace kj { + +CidrRange::CidrRange(StringPtr pattern) { + size_t slashPos = KJ_REQUIRE_NONNULL(pattern.findFirst('/'), "invalid CIDR", pattern); + + bitCount = pattern.slice(slashPos + 1).parseAs(); + + KJ_STACK_ARRAY(char, addr, slashPos + 1, 128, 128); + memcpy(addr.begin(), pattern.begin(), slashPos); + addr[slashPos] = '\0'; + + if (pattern.findFirst(':') == nullptr) { + family = AF_INET; + KJ_REQUIRE(bitCount <= 32, "invalid CIDR", pattern); + } else { + family = AF_INET6; + KJ_REQUIRE(bitCount <= 128, "invalid CIDR", pattern); + } + + KJ_ASSERT(inet_pton(family, addr.begin(), bits) > 0, "invalid CIDR", pattern); + zeroIrrelevantBits(); +} + +CidrRange::CidrRange(int family, ArrayPtr bits, uint bitCount) + : family(family), bitCount(bitCount) { + if (family == AF_INET) { + KJ_REQUIRE(bitCount <= 32); + } else { + KJ_REQUIRE(bitCount <= 128); + } + KJ_REQUIRE(bits.size() * 8 >= bitCount); + size_t byteCount = (bitCount + 7) / 8; + memcpy(this->bits, bits.begin(), byteCount); + memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount); + + zeroIrrelevantBits(); +} + +CidrRange CidrRange::inet4(ArrayPtr bits, uint bitCount) { + return CidrRange(AF_INET, bits, bitCount); +} +CidrRange CidrRange::inet6( + ArrayPtr prefix, ArrayPtr suffix, + uint bitCount) { + KJ_REQUIRE(prefix.size() + suffix.size() <= 8); + + byte bits[16] = { 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, }; + + for (size_t i: kj::indices(prefix)) { + bits[i * 2] = prefix[i] >> 8; + bits[i * 2 + 1] = prefix[i] & 0xff; + } + + byte* suffixBits = bits + (16 - suffix.size() * 2); + for (size_t i: kj::indices(suffix)) { + suffixBits[i * 2] = suffix[i] >> 8; + suffixBits[i * 2 + 1] = suffix[i] & 0xff; + } + + return CidrRange(AF_INET6, bits, bitCount); +} + +bool CidrRange::matches(const struct sockaddr* addr) const { + const byte* otherBits; + + switch (family) { + case AF_INET: + if (addr->sa_family == AF_INET6) { + otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; + static constexpr byte V6MAPPED[12] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff }; + if (memcmp(otherBits, V6MAPPED, sizeof(V6MAPPED)) == 0) { + // We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning + // it's equivalent to an ipv4 address. Try to match against the ipv4 part. + otherBits = otherBits + sizeof(V6MAPPED); + } else { + return false; + } + } else if (addr->sa_family == AF_INET) { + otherBits = reinterpret_cast( + &reinterpret_cast(addr)->sin_addr.s_addr); + } else { + return false; + } + + break; + + case AF_INET6: + if (addr->sa_family != AF_INET6) return false; + + otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; + break; + + default: + KJ_UNREACHABLE; + } + + if (memcmp(bits, otherBits, bitCount / 8) != 0) return false; + + return bitCount == 128 || + bits[bitCount / 8] == (otherBits[bitCount / 8] & (0xff00 >> (bitCount % 8))); +} + +bool CidrRange::matchesFamily(int family) const { + switch (family) { + case AF_INET: + return this->family == AF_INET; + case AF_INET6: + // Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range. + return true; + default: + return false; + } +} + +String CidrRange::toString() const { + char result[128]; + KJ_ASSERT(inet_ntop(family, (void*)bits, result, sizeof(result)) == result); + return kj::str(result, '/', bitCount); +} + +void CidrRange::zeroIrrelevantBits() { + // Mask out insignificant bits of partial byte. + if (bitCount < 128) { + bits[bitCount / 8] &= 0xff00 >> (bitCount % 8); + + // Zero the remaining bytes. + size_t n = bitCount / 8 + 1; + memset(bits + n, 0, sizeof(bits) - n); + } +} + +} // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/cidr.h b/libs/EXTERNAL/capnproto/c++/src/kj/cidr.h new file mode 100644 index 00000000000..b334ecc7d4c --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/cidr.h @@ -0,0 +1,62 @@ + +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "common.h" +#include + +KJ_BEGIN_HEADER + +struct sockaddr; + +namespace kj { + +class CidrRange { +public: + CidrRange(StringPtr pattern); + + static CidrRange inet4(ArrayPtr bits, uint bitCount); + static CidrRange inet6(ArrayPtr prefix, ArrayPtr suffix, + uint bitCount); + // Zeros are inserted between `prefix` and `suffix` to extend the address to 128 bits. + + uint getSpecificity() const { return bitCount; } + + bool matches(const struct sockaddr* addr) const; + bool matchesFamily(int family) const; + + String toString() const; + +private: + int family; + byte bits[16]; + uint bitCount; // how many bits in `bits` need to match + + CidrRange(int family, ArrayPtr bits, uint bitCount); + + void zeroIrrelevantBits(); +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/common-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/common-test.c++ index 0924d41dabc..9785612562b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/common-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/common-test.c++ @@ -47,7 +47,7 @@ struct ImplicitToInt { struct Immovable { Immovable() = default; - KJ_DISALLOW_COPY(Immovable); + KJ_DISALLOW_COPY_AND_MOVE(Immovable); }; struct CopyOrMove { @@ -97,6 +97,43 @@ TEST(Common, Maybe) { } } + { + Maybe> m = kj::heap(123); + EXPECT_FALSE(m == nullptr); + EXPECT_TRUE(m != nullptr); + KJ_IF_MAYBE(v, m) { + EXPECT_EQ(123, (*v)->i); + } else { + ADD_FAILURE(); + } + KJ_IF_MAYBE(v, mv(m)) { + EXPECT_EQ(123, (*v)->i); + } else { + ADD_FAILURE(); + } + // We have moved the kj::Own away, so this should give us the default and leave the Maybe empty. + EXPECT_EQ(456, m.orDefault(heap(456))->i); + EXPECT_TRUE(m == nullptr); + + bool ranLazy = false; + EXPECT_EQ(123, mv(m).orDefault([&] { + ranLazy = true; + return heap(123); + })->i); + EXPECT_TRUE(ranLazy); + EXPECT_TRUE(m == nullptr); + + m = heap(123); + EXPECT_TRUE(m != nullptr); + ranLazy = false; + EXPECT_EQ(123, mv(m).orDefault([&] { + ranLazy = true; + return heap(456); + })->i); + EXPECT_FALSE(ranLazy); + EXPECT_TRUE(m == nullptr); + } + { Maybe empty; int defaultValue = 5; @@ -418,9 +455,99 @@ TEST(Common, MaybeConstness) { } } +#if __GNUC__ +TEST(Common, MaybeUnwrapOrReturn) { + { + auto func = [](Maybe i) -> int { + int& j = KJ_UNWRAP_OR_RETURN(i, -1); + KJ_EXPECT(&j == &KJ_ASSERT_NONNULL(i)); + return j + 2; + }; + + KJ_EXPECT(func(123) == 125); + KJ_EXPECT(func(nullptr) == -1); + } + + { + auto func = [&](Maybe maybe) -> int { + String str = KJ_UNWRAP_OR_RETURN(kj::mv(maybe), -1); + return str.parseAs(); + }; + + KJ_EXPECT(func(kj::str("123")) == 123); + KJ_EXPECT(func(nullptr) == -1); + } + + // Test void return. + { + int val = 0; + auto func = [&](Maybe i) { + val = KJ_UNWRAP_OR_RETURN(i); + }; + + func(123); + KJ_EXPECT(val == 123); + val = 321; + func(nullptr); + KJ_EXPECT(val == 321); + } + + // Test KJ_UNWRAP_OR + { + bool wasNull = false; + auto func = [&](Maybe i) -> int { + int& j = KJ_UNWRAP_OR(i, { + wasNull = true; + return -1; + }); + KJ_EXPECT(&j == &KJ_ASSERT_NONNULL(i)); + return j + 2; + }; + + KJ_EXPECT(func(123) == 125); + KJ_EXPECT(!wasNull); + KJ_EXPECT(func(nullptr) == -1); + KJ_EXPECT(wasNull); + } + + { + bool wasNull = false; + auto func = [&](Maybe maybe) -> int { + String str = KJ_UNWRAP_OR(kj::mv(maybe), { + wasNull = true; + return -1; + }); + return str.parseAs(); + }; + + KJ_EXPECT(func(kj::str("123")) == 123); + KJ_EXPECT(!wasNull); + KJ_EXPECT(func(nullptr) == -1); + KJ_EXPECT(wasNull); + } + + // Test void return. + { + int val = 0; + auto func = [&](Maybe i) { + val = KJ_UNWRAP_OR(i, { + return; + }); + }; + + func(123); + KJ_EXPECT(val == 123); + val = 321; + func(nullptr); + KJ_EXPECT(val == 321); + } + +} +#endif + class Foo { public: - KJ_DISALLOW_COPY(Foo); + KJ_DISALLOW_COPY_AND_MOVE(Foo); virtual ~Foo() {} protected: Foo() = default; @@ -429,14 +556,14 @@ protected: class Bar: public Foo { public: Bar() = default; - KJ_DISALLOW_COPY(Bar); + KJ_DISALLOW_COPY_AND_MOVE(Bar); virtual ~Bar() {} }; class Baz: public Foo { public: Baz() = delete; - KJ_DISALLOW_COPY(Baz); + KJ_DISALLOW_COPY_AND_MOVE(Baz); virtual ~Baz() {} }; @@ -684,6 +811,10 @@ KJ_TEST("ArrayPtr operator ==") { ArrayPtr({"foo", "baz"}))); KJ_EXPECT((ArrayPtr({"foo", "bar"}) != ArrayPtr({"foo"}))); + + // operator== should not use memcmp for double elements. + double d[1] = { nan() }; + KJ_EXPECT(ArrayPtr(d, 1) != ArrayPtr(d, 1)); } KJ_TEST("kj::range()") { @@ -701,31 +832,97 @@ KJ_TEST("kj::range()") { } KJ_TEST("kj::defer()") { - bool executed; + { + // rvalue reference + bool executed = false; + { + auto deferred = kj::defer([&executed]() { + executed = true; + }); + KJ_EXPECT(!executed); + } + + KJ_EXPECT(executed); + } - // rvalue reference { - executed = false; - auto deferred = kj::defer([&executed]() { + // lvalue reference + bool executed = false; + auto executor = [&executed]() { executed = true; - }); - KJ_EXPECT(!executed); - } + }; - KJ_EXPECT(executed); + { + auto deferred = kj::defer(executor); + KJ_EXPECT(!executed); + } - // lvalue reference - auto executor = [&executed]() { - executed = true; - }; + KJ_EXPECT(executed); + } { - executed = false; - auto deferred = kj::defer(executor); + // Cancellation via `cancel()`. + bool executed = false; + { + auto deferred = kj::defer([&executed]() { + executed = true; + }); + KJ_EXPECT(!executed); + + // Cancel and release the functor. + deferred.cancel(); + KJ_EXPECT(!executed); + } + KJ_EXPECT(!executed); } - KJ_EXPECT(executed); + { + // Execution via `run()`. + size_t runCount = 0; + { + auto deferred = kj::defer([&runCount](){ + ++runCount; + }); + + // Run and release the functor. + deferred.run(); + KJ_EXPECT(runCount == 1); + } + + // `deferred` is already been run, so nothing is run when we destruct it. + KJ_EXPECT(runCount == 1); + } + +} + +KJ_TEST("kj::ArrayPtr startsWith / endsWith / findFirst / findLast") { + // Note: char-/byte- optimized versions are covered by string-test.c++. + + int rawArray[] = {12, 34, 56, 34, 12}; + ArrayPtr arr(rawArray); + + KJ_EXPECT(arr.startsWith({12, 34})); + KJ_EXPECT(arr.startsWith({12, 34, 56})); + KJ_EXPECT(!arr.startsWith({12, 34, 56, 78})); + KJ_EXPECT(arr.startsWith({12, 34, 56, 34, 12})); + KJ_EXPECT(!arr.startsWith({12, 34, 56, 34, 12, 12})); + + KJ_EXPECT(arr.endsWith({34, 12})); + KJ_EXPECT(arr.endsWith({56, 34, 12})); + KJ_EXPECT(!arr.endsWith({78, 56, 34, 12})); + KJ_EXPECT(arr.endsWith({12, 34, 56, 34, 12})); + KJ_EXPECT(!arr.endsWith({12, 12, 34, 56, 34, 12})); + + KJ_EXPECT(arr.findFirst(12).orDefault(100) == 0); + KJ_EXPECT(arr.findFirst(34).orDefault(100) == 1); + KJ_EXPECT(arr.findFirst(56).orDefault(100) == 2); + KJ_EXPECT(arr.findFirst(78).orDefault(100) == 100); + + KJ_EXPECT(arr.findLast(12).orDefault(100) == 4); + KJ_EXPECT(arr.findLast(34).orDefault(100) == 3); + KJ_EXPECT(arr.findLast(56).orDefault(100) == 2); + KJ_EXPECT(arr.findLast(78).orDefault(100) == 100); } } // namespace diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/common.h b/libs/EXTERNAL/capnproto/c++/src/kj/common.h index 2caba972f5f..42ecbbce61e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/common.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/common.h @@ -60,12 +60,18 @@ #define KJ_HAS_COMPILER_FEATURE(x) 0 #endif +#if defined(_MSVC_LANG) && !defined(__clang__) +#define KJ_CPP_STD _MSVC_LANG +#else +#define KJ_CPP_STD __cplusplus +#endif + KJ_BEGIN_HEADER #ifndef KJ_NO_COMPILER_CHECK // Technically, __cplusplus should be 201402L for C++14, but GCC 4.9 -- which is supported -- still // had it defined to 201300L even with -std=c++14. -#if __cplusplus < 201300L && !__CDT_PARSER__ && !_MSC_VER +#if KJ_CPP_STD < 201300L && !__CDT_PARSER__ #error "This code requires C++14. Either your compiler does not support it or it is not enabled." #ifdef __GNUC__ // Compiler claims compatibility with GCC, so presumably supports -std. @@ -77,7 +83,7 @@ KJ_BEGIN_HEADER #if __clang__ #if __clang_major__ < 5 #warning "This library requires at least Clang 5.0." - #elif __cplusplus >= 201402L && !__has_include() + #elif KJ_CPP_STD >= 201402L && !__has_include() #warning "Your compiler supports C++14 but your C++ standard library does not. If your "\ "system has libc++ installed (as should be the case on e.g. Mac OSX), try adding "\ "-stdlib=libc++ to your CXXFLAGS." @@ -99,9 +105,11 @@ KJ_BEGIN_HEADER #endif #include +#include #include +#include -#if __linux__ && __cplusplus > 201200L +#if __linux__ && KJ_CPP_STD > 201200L // Hack around stdlib bug with C++14 that exists on some Linux systems. // Apparently in this mode the C library decides not to define gets() but the C++ library still // tries to import it into the std namespace. This bug has been fixed at the source but is still @@ -109,10 +117,19 @@ KJ_BEGIN_HEADER #undef _GLIBCXX_HAVE_GETS #endif -#if defined(_MSC_VER) +#if _WIN32 +// Windows likes to define macros for min() and max(). We just can't deal with this. +// If windows.h was included already, undef these. +#undef min +#undef max +// If windows.h was not included yet, define the macro that prevents min() and max() from being +// defined. #ifndef NOMINMAX #define NOMINMAX 1 #endif +#endif + +#if defined(_MSC_VER) #include // __popcnt #endif @@ -174,7 +191,22 @@ typedef unsigned char byte; #define KJ_DISALLOW_COPY(classname) \ classname(const classname&) = delete; \ classname& operator=(const classname&) = delete -// Deletes the implicit copy constructor and assignment operator. +// Deletes the implicit copy constructor and assignment operator. This inhibits the compiler from +// generating the implicit move constructor and assignment operator for this class, but allows the +// code author to supply them, if they make sense to implement. +// +// This macro should not be your first choice. Instead, prefer using KJ_DISALLOW_COPY_AND_MOVE, and only use +// this macro when you have determined that you must implement move semantics for your type. + +#define KJ_DISALLOW_COPY_AND_MOVE(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(classname&&) = delete +// Deletes the implicit copy and move constructors and assignment operators. This is useful in cases +// where the code author wants to provide an additional compile-time guard against subsequent +// maintainers casually adding move operations. This is particularly useful when implementing RAII +// classes that are intended to be completely immobile. #ifdef __GNUC__ #define KJ_LIKELY(condition) __builtin_expect(condition, true) @@ -249,7 +281,7 @@ typedef unsigned char byte; #define KJ_UNUSED_MEMBER #endif -#if __cplusplus > 201703L || (__clang__ && __clang_major__ >= 9 && __cplusplus >= 201103L) +#if KJ_CPP_STD > 201703L || (__clang__ && __clang_major__ >= 9 && KJ_CPP_STD >= 201103L) // Technically this was only added to C++20 but Clang allows it for >= C++11 and spelunking the // attributes manual indicates it first came in with Clang 9. #define KJ_NO_UNIQUE_ADDRESS [[no_unique_address]] @@ -271,10 +303,14 @@ typedef unsigned char byte; #elif __GNUC__ #define KJ_DEPRECATED(reason) \ __attribute__((deprecated)) -#define KJ_UNAVAILABLE(reason) +#define KJ_UNAVAILABLE(reason) = delete +// If the `unavailable` attribute is not supproted, just mark the method deleted, which at least +// makes it a compile-time error to try to call it. Note that on Clang, marking a method deleted +// *and* unavailable unfortunately defeats the purpose of the unavailable annotation, as the +// generic "deleted" error is reported instead. #else #define KJ_DEPRECATED(reason) -#define KJ_UNAVAILABLE(reason) +#define KJ_UNAVAILABLE(reason) = delete // TODO(msvc): Again, here, MSVC prefers a prefix, __declspec(deprecated). #endif @@ -293,8 +329,12 @@ KJ_NORETURN(void unreachable()); } // namespace _ (private) +#if _MSC_VER && !defined(__clang__) && (!defined(_MSVC_TRADITIONAL) || _MSVC_TRADITIONAL) +#define KJ_MSVC_TRADITIONAL_CPP 1 +#endif + #ifdef KJ_DEBUG -#if _MSC_VER && !defined(__clang__) +#if KJ_MSVC_TRADITIONAL_CPP #define KJ_IREQUIRE(condition, ...) \ if (KJ_LIKELY(condition)); else ::kj::_::inlineRequireFailure( \ __FILE__, __LINE__, #condition, "" #__VA_ARGS__, __VA_ARGS__) @@ -408,6 +448,15 @@ KJ_NORETURN(void unreachable()); // ======================================================================================= // Template metaprogramming helpers. +#define KJ_HAS_TRIVIAL_CONSTRUCTOR __is_trivially_constructible +#if __GNUC__ && !__clang__ +#define KJ_HAS_NOTHROW_CONSTRUCTOR __has_nothrow_constructor +#define KJ_HAS_TRIVIAL_DESTRUCTOR __has_trivial_destructor +#else +#define KJ_HAS_NOTHROW_CONSTRUCTOR __is_nothrow_constructible +#define KJ_HAS_TRIVIAL_DESTRUCTOR __is_trivially_destructible +#endif + template struct NoInfer_ { typedef T Type; }; template using NoInfer = typename NoInfer_::Type; // Use NoInfer::Type in place of T for a template function parameter to prevent inference of @@ -560,6 +609,19 @@ template struct IsSameType_ { static constexpr bool val template struct IsSameType_ { static constexpr bool value = true; }; template constexpr bool isSameType() { return IsSameType_::value; } +template constexpr bool isIntegral() { return false; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } + template struct CanConvert_ { static int sfinae(T); @@ -1011,8 +1073,7 @@ inline void dtor(T& location) { // forces the caller to handle the null case in order to satisfy the compiler, thus reliably // preventing null pointer dereferences at runtime. // -// Maybe can be implicitly constructed from T and from nullptr. Additionally, it can be -// implicitly constructed from T*, in which case the pointer is checked for nullness at runtime. +// Maybe can be implicitly constructed from T and from nullptr. // To read the value of a Maybe, do: // // KJ_IF_MAYBE(value, someFuncReturningMaybe()) { @@ -1260,6 +1321,63 @@ inline T* readMaybe(T* ptr) { return ptr; } #define KJ_IF_MAYBE(name, exp) if (auto name = ::kj::_::readMaybe(exp)) +#if __GNUC__ || __clang__ +// These two macros provide a friendly syntax to extract the value of a Maybe or return early. +// +// Use KJ_UNWRAP_OR_RETURN if you just want to return a simple value when the Maybe is null: +// +// int foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR_RETURN(maybe, -1); +// // ... use value ... +// } +// +// For functions returning void, omit the second parameter to KJ_UNWRAP_OR_RETURN: +// +// void foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR_RETURN(maybe); +// // ... use value ... +// } +// +// Use KJ_UNWRAP_OR if you want to execute a block with multiple statements. +// +// int foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR(maybe, { +// KJ_LOG(ERROR, "problem!!!"); +// return -1; +// }); +// // ... use value ... +// } +// +// The block MUST return at the end or you will get a compiler error +// +// Unfortunately, these macros seem impossible to express without using GCC's non-standard +// "statement expressions" extension. IIFEs don't do the trick here because a lambda cannot +// return out of the parent scope. These macros should therefore only be used in projects that +// target GCC or GCC-compatible compilers. +// +// `__GNUC__` is not defined when using LLVM's MSVC-compatible compiler driver `clang-cl` (even +// though clang supports the required extension), hence the additional `|| __clang__`. + +#define KJ_UNWRAP_OR_RETURN(value, ...) \ + (*({ \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (!_kj_result) { \ + return __VA_ARGS__; \ + } \ + kj::mv(_kj_result); \ + })) + +#define KJ_UNWRAP_OR(value, block) \ + (*({ \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (!_kj_result) { \ + block; \ + asm("KJ_UNWRAP_OR_block_is_missing_return_statement\n"); \ + } \ + kj::mv(_kj_result); \ + })) +#endif + template class Maybe { // A T, or nullptr. @@ -1683,6 +1801,29 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { KJ_IREQUIRE(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice()."); return ArrayPtr(ptr + start, end - start); } + inline bool startsWith(const ArrayPtr& other) const { + return other.size() <= size_ && slice(0, other.size()) == other; + } + inline bool endsWith(const ArrayPtr& other) const { + return other.size() <= size_ && slice(size_ - other.size(), size_) == other; + } + + inline Maybe findFirst(const T& match) const { + for (size_t i = 0; i < size_; i++) { + if (ptr[i] == match) { + return i; + } + } + return nullptr; + } + inline Maybe findLast(const T& match) const { + for (size_t i = size_; i--;) { + if (ptr[i] == match) { + return i; + } + } + return nullptr; + } inline ArrayPtr> asBytes() const { // Reinterpret the array as a byte array. This is explicitly legal under C++ aliasing @@ -1700,12 +1841,18 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { inline bool operator==(const ArrayPtr& other) const { if (size_ != other.size_) return false; + if (isIntegral>()) { + if (size_ == 0) return true; + return memcmp(ptr, other.ptr, size_ * sizeof(T)) == 0; + } for (size_t i = 0; i < size_; i++) { if (ptr[i] != other[i]) return false; } return true; } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const ArrayPtr& other) const { return !(*this == other); } +#endif template inline bool operator==(const ArrayPtr& other) const { @@ -1715,8 +1862,10 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { } return true; } +#if !__cpp_impl_three_way_comparison template inline bool operator!=(const ArrayPtr& other) const { return !(*this == other); } +#endif template Array attach(Attachments&&... attachments) const KJ_WARN_UNUSED_RESULT; @@ -1730,6 +1879,49 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { size_t size_; }; +template <> +inline Maybe ArrayPtr::findFirst(const char& c) const { + const char* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const char& c) const { + char* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const byte& c) const { + const byte* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const byte& c) const { + byte* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +// glibc has a memrchr() for reverse search but it's non-standard, so we don't bother optimizing +// findLast(), which isn't used much anyway. + template inline constexpr ArrayPtr arrayPtr(T* ptr KJ_LIFETIMEBOUND, size_t size) { // Use this function to construct ArrayPtrs without writing out the type name. @@ -1799,29 +1991,48 @@ namespace _ { // private template class Deferred { public: - inline Deferred(Func&& func): func(kj::fwd(func)), canceled(false) {} - inline ~Deferred() noexcept(false) { if (!canceled) func(); } + Deferred(Func&& func): maybeFunc(kj::fwd(func)) {} + ~Deferred() noexcept(false) { + run(); + } KJ_DISALLOW_COPY(Deferred); - // This move constructor is usually optimized away by the compiler. - inline Deferred(Deferred&& other): func(kj::fwd(other.func)), canceled(false) { - other.canceled = true; + Deferred(Deferred&&) = default; + // Since we use a kj::Maybe, the default move constructor does exactly what we want it to do. + + void run() { + // Move `maybeFunc` to the local scope so that even if we throw, we destroy the functor we had. + auto maybeLocalFunc = kj::mv(maybeFunc); + KJ_IF_MAYBE(func, maybeLocalFunc) { + (*func)(); + } + } + + void cancel() { + maybeFunc = nullptr; } + private: - Func func; - bool canceled; + kj::Maybe maybeFunc; + // Note that `Func` may actually be an lvalue reference because `kj::defer` takes its argument via + // universal reference. `kj::Maybe` has specializations for lvalue reference types, so this works + // out. }; } // namespace _ (private) template _::Deferred defer(Func&& func) { - // Returns an object which will invoke the given functor in its destructor. The object is not - // copyable but is movable with the semantics you'd expect. Since the return type is private, - // you need to assign to an `auto` variable. + // Returns an object which will invoke the given functor in its destructor. The object is not + // copyable but is move-constructable with the semantics you'd expect. Since the return type is + // private, you need to assign to an `auto` variable. // // The KJ_DEFER macro provides slightly more convenient syntax for the common case where you // want some code to run at current scope exit. + // + // KJ_DEFER does not support move-assignment for its returned objects. If you need to reuse the + // variable for your deferred function object, then you will want to write your own class for that + // purpose. return _::Deferred(kj::fwd(func)); } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/BUILD.bazel b/libs/EXTERNAL/capnproto/c++/src/kj/compat/BUILD.bazel new file mode 100644 index 00000000000..f9e31ce6be7 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/BUILD.bazel @@ -0,0 +1,155 @@ +exports_files(["gtest.h"]) + +cc_library( + name = "kj-tls", + srcs = [ + "readiness-io.c++", + "tls.c++", + ], + hdrs = [ + "readiness-io.h", + "tls.h", + ], + include_prefix = "kj/compat", + target_compatible_with = select({ + "//src/kj:use_openssl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@ssl", + ], +) + +cc_library( + name = "kj-http", + srcs = [ + "http.c++", + "url.c++", + ], + hdrs = [ + "http.h", + "url.h", + ], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@zlib", + ], +) + +cc_library( + name = "kj-gzip", + srcs = ["gzip.c++"], + hdrs = ["gzip.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@zlib", + ], +) + +cc_library( + name = "kj-brotli", + srcs = ["brotli.c++"], + hdrs = ["brotli.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + target_compatible_with = select({ + "//src/kj:use_brotli": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + "//src/kj:kj-async", + "@brotli//:brotlienc", + "@brotli//:brotlidec", + ], +) + +cc_library( + name = "gtest", + hdrs = ["gtest.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = ["//src/kj"], +) + +kj_tests = [ + "http-test.c++", + "url-test.c++", +] + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":kj-http", + "//src/kj:kj-test", + ], +) for f in kj_tests] + +cc_library( + name = "http-socketpair-test-base", + hdrs = ["http-test.c++"], +) + +cc_test( + name = "http-socketpair-test", + srcs = ["http-socketpair-test.c++"], + deps = [ + ":http-socketpair-test-base", + ":kj-http", + "//src/kj:kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", # TODO: Investigate why this fails on macOS + ], +) + +kj_tls_tests = [ + "tls-test.c++", + "readiness-io-test.c++", +] + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + target_compatible_with = select({ + "//src/kj:use_openssl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-tls", + ":kj-http", + "//src/kj:kj-test", + ], +) for f in kj_tls_tests] + +cc_test( + name = "gzip-test", + srcs = ["gzip-test.c++"], + target_compatible_with = select({ + "//src/kj:use_zlib": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-gzip", + "//src/kj:kj-test", + ], +) + +cc_test( + name = "brotli-test", + srcs = ["brotli-test.c++"], + target_compatible_with = select({ + "//src/kj:use_brotli": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-brotli", + "//src/kj:kj-test", + ], +) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli-test.c++ new file mode 100644 index 00000000000..f0e00d1e084 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli-test.c++ @@ -0,0 +1,410 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_BROTLI + +#include "brotli.h" +#include +#include +#include + +namespace kj { +namespace { + +static const byte FOOBAR_BR[] = { + 0x83, 0x02, 0x80, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, 0x03, +}; + +// brotli stream with 24 window bits, i.e. the max window size. If KJ_BROTLI_MAX_DEC_WBITS is less +// than 24, the stream will be rejected by default. This approach should be acceptable in a web +// context, where few files benefit from larger windows and memory usage matters for +// concurrent transfers. +static const byte FOOBAR_BR_LARGE_WIN[] = { + 0x8f, 0x02, 0x80, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, 0x03, +}; + +class MockInputStream: public InputStream { +public: + MockInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockAsyncInputStream: public AsyncInputStream { +public: + MockAsyncInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockOutputStream: public OutputStream { +public: + kj::Vector bytes; + + kj::String decompress() { + MockInputStream rawInput(bytes, kj::maxValue); + BrotliInputStream brotli(rawInput); + return brotli.readAllText(); + } + + void write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + } + void write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + } +}; + +class MockAsyncOutputStream: public AsyncOutputStream { +public: + kj::Vector bytes; + + kj::String decompress(WaitScope& ws) { + MockAsyncInputStream rawInput(bytes, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + return brotli.readAllText().wait(ws); + } + + Promise write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + return kj::READY_NOW; + } + + Promise whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); } +}; + +KJ_TEST("brotli decompression") { + // Normal read. + { + MockInputStream rawInput(FOOBAR_BR, kj::maxValue); + BrotliInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Force read one byte at a time. + { + MockInputStream rawInput(FOOBAR_BR, 1); + BrotliInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Read truncated input. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_BR, sizeof(FOOBAR_BR) / 2), kj::maxValue); + BrotliInputStream brotli(rawInput); + + char text[16]; + size_t n = brotli.tryRead(text, 1, sizeof(text)); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("brotli compressed stream ended prematurely", + brotli.tryRead(text, 1, sizeof(text))); + } + + // Check that stream with high window size is rejected. Conversely, check that it is accepted if + // configured to accept the full window size. + { + MockInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliInputStream brotli(rawInput, BROTLI_DEFAULT_WINDOW); + KJ_EXPECT_THROW_MESSAGE("brotli window size too big", brotli.readAllText()); + } + + { + MockInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliInputStream brotli(rawInput, BROTLI_MAX_WINDOW_BITS); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Check that invalid stream is rejected. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_BR + 3, sizeof(FOOBAR_BR) - 3), kj::maxValue); + BrotliInputStream brotli(rawInput); + KJ_EXPECT_THROW_MESSAGE("brotli decompression failed", brotli.readAllText()); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_BR)); + bytes.addAll(ArrayPtr(FOOBAR_BR)); + MockInputStream rawInput(bytes, kj::maxValue); + BrotliInputStream brotli(rawInput); + + KJ_EXPECT(brotli.readAllText() == "foobarfoobar"); + } +} + +KJ_TEST("async brotli decompression") { + auto io = setupAsyncIo(); + + // Normal read. + { + MockAsyncInputStream rawInput(FOOBAR_BR, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Force read one byte at a time. + { + MockAsyncInputStream rawInput(FOOBAR_BR, 1); + BrotliAsyncInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Read truncated input. + { + MockAsyncInputStream rawInput(kj::arrayPtr(FOOBAR_BR, sizeof(FOOBAR_BR) / 2), kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + + char text[16]; + size_t n = brotli.tryRead(text, 1, sizeof(text)).wait(io.waitScope); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("brotli compressed stream ended prematurely", + brotli.tryRead(text, 1, sizeof(text)).wait(io.waitScope)); + } + + // Check that stream with high window size is rejected. Conversely, check that it is accepted if + // configured to accept the full window size. + { + MockAsyncInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput, BROTLI_DEFAULT_WINDOW); + KJ_EXPECT_THROW_MESSAGE("brotli window size too big", + brotli.readAllText().wait(io.waitScope)); + } + + { + MockAsyncInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput, BROTLI_MAX_WINDOW_BITS); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_BR)); + bytes.addAll(ArrayPtr(FOOBAR_BR)); + MockAsyncInputStream rawInput(bytes, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobarfoobar"); + } + + // Decompress using an output stream. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput, BrotliAsyncOutputStream::DECOMPRESS); + + auto mid = sizeof(FOOBAR_BR) / 2; + brotli.write(FOOBAR_BR, mid).wait(io.waitScope); + auto str1 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str1 == "fo", str1); + + brotli.write(FOOBAR_BR + mid, sizeof(FOOBAR_BR) - mid).wait(io.waitScope); + auto str2 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str2 == "foobar", str2); + + brotli.end().wait(io.waitScope); + } +} + +KJ_TEST("brotli compression") { + // Normal write. + { + MockOutputStream rawOutput; + { + BrotliOutputStream brotli(rawOutput); + brotli.write("foobar", 6); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Multi-part write. + { + MockOutputStream rawOutput; + { + BrotliOutputStream brotli(rawOutput); + brotli.write("foo", 3); + brotli.write("bar", 3); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Array-of-arrays write. + { + MockOutputStream rawOutput; + + { + BrotliOutputStream brotli(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + brotli.write(pieces); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } +} + +KJ_TEST("brotli huge round trip") { + auto bytes = heapArray(96*1024); + for (auto& b: bytes) { + b = rand(); + } + + MockOutputStream rawOutput; + { + BrotliOutputStream brotliOut(rawOutput); + brotliOut.write(bytes.begin(), bytes.size()); + } + + MockInputStream rawInput(rawOutput.bytes, kj::maxValue); + BrotliInputStream brotliIn(rawInput); + auto decompressed = brotliIn.readAllBytes(); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +KJ_TEST("async brotli compression") { + auto io = setupAsyncIo(); + // Normal write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + brotli.write("foobar", 6).wait(io.waitScope); + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Multi-part write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + + brotli.write("foo", 3).wait(io.waitScope); + auto prevSize = rawOutput.bytes.size(); + + brotli.write("bar", 3).wait(io.waitScope); + auto curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize == curSize, prevSize, curSize); + + brotli.flush().wait(io.waitScope); + curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize < curSize, prevSize, curSize); + + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Array-of-arrays write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + brotli.write(pieces).wait(io.waitScope); + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } +} + +KJ_TEST("async brotli huge round trip") { + auto io = setupAsyncIo(); + + auto bytes = heapArray(65536); + for (auto& b: bytes) { + b = rand(); + } + + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotliOut(rawOutput); + brotliOut.write(bytes.begin(), bytes.size()).wait(io.waitScope); + brotliOut.end().wait(io.waitScope); + + MockAsyncInputStream rawInput(rawOutput.bytes, kj::maxValue); + BrotliAsyncInputStream brotliIn(rawInput); + auto decompressed = brotliIn.readAllBytes().wait(io.waitScope); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +} // namespace +} // namespace kj + +#endif // KJ_HAS_BROTLI diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.c++ new file mode 100644 index 00000000000..08efc8abfa7 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.c++ @@ -0,0 +1,369 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_BROTLI + +#include "brotli.h" +#include + +namespace kj { + +namespace { + +int getBrotliWindowBits(kj::byte peek) { + // Check number of window bits used by the stream, see RFC 7932 + // (https://www.rfc-editor.org/rfc/rfc7932.html#section-9.1) for the specification. + // Adapted from an internal Cloudflare codebase. + if ((peek & 0x01) == 0) { + return 16; + } + + if (((peek >> 1) & 0x07) != 0) { + return 17 + (peek >> 1 & 0x07); + } + + if (((peek >> 4) & 0x07) == 0) { + return 17; + } + + if (((peek >> 4) & 0x07) == 1) { + // Large window brotli, not part of RFC 7932 and not supported in web contexts + return BROTLI_MAX_WINDOW_BITS + 1; + } + + return 8 + ((peek >> 4) & 0x07); +} + +} // namespace + +namespace _ { // private + +BrotliOutputContext::BrotliOutputContext(kj::Maybe compressionLevel, + kj::Maybe windowBitsParam) + : nextIn(nullptr), availableIn(0) { + KJ_IF_MAYBE(level, compressionLevel) { + // Emulate zlib's behavior of using -1 to signify the default quality + if (*level == -1) {*level = KJ_BROTLI_DEFAULT_QUALITY;} + KJ_REQUIRE(*level >= BROTLI_MIN_QUALITY && *level <= BROTLI_MAX_QUALITY, + "invalid brotli compression level", *level); + windowBits = windowBitsParam.orDefault(_::KJ_BROTLI_DEFAULT_WBITS); + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + BrotliEncoderState* cctx = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(cctx, "brotli state allocation failed"); + KJ_ASSERT(BrotliEncoderSetParameter(cctx, BROTLI_PARAM_QUALITY, *level) == BROTLI_TRUE); + KJ_ASSERT(BrotliEncoderSetParameter(cctx, BROTLI_PARAM_LGWIN, windowBits) == BROTLI_TRUE); + ctx = cctx; + } else { + // In the decoder, we manually check that the stream does not have a higher window size than + // requested and reject it otherwise, no way to automate this step. + // By default, we accept streams with a window size up to (1 << KJ_BROTLI_MAX_DEC_WBITS), + // this is more than the default window size for compression (i.e. KJ_BROTLI_DEFAULT_WBITS) + windowBits = windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS); + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + BrotliDecoderState* dctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(dctx, "brotli state allocation failed"); + ctx = dctx; + } +} + +BrotliOutputContext::~BrotliOutputContext() noexcept(false) { + KJ_SWITCH_ONEOF(ctx) { + KJ_CASE_ONEOF(cctx, BrotliEncoderState*) { + BrotliEncoderDestroyInstance(cctx); + } + KJ_CASE_ONEOF(dctx, BrotliDecoderState*) { + BrotliDecoderDestroyInstance(dctx); + } + } +} + +void BrotliOutputContext::setInput(const void* in, size_t size) { + nextIn = reinterpret_cast(in); + availableIn = size; +} + +kj::Tuple> BrotliOutputContext::pumpOnce( + BrotliEncoderOperation flush) { + byte* nextOut = buffer; + size_t availableOut = sizeof(buffer); + // Brotli does not accept a null input pointer; make sure there is a valid pointer even if we are + // not actually reading from it. + if (!nextIn) { + KJ_ASSERT(availableIn == 0); + nextIn = buffer; + } + + KJ_SWITCH_ONEOF(ctx) { + KJ_CASE_ONEOF(dctx, BrotliDecoderState*) { + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream(dctx, &availableIn, &nextIn, + &availableOut, &nextOut, nullptr); + if (result == BROTLI_DECODER_RESULT_ERROR) { + // Note: Unlike BrotliInputStream, this will implicitly reject trailing data during + // decompression, matching the behavior for gzip. + KJ_FAIL_REQUIRE("brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(dctx))); + } + // The 'ok' parameter represented by the first parameter of the tuple indicates that + // pumpOnce() should be called again as more output data can be produced. This is the case + // when the stream is not finished and there is either pending output data (that didn't fit + // into the buffer) or input that has not been processed yet. + return kj::tuple(BrotliDecoderHasMoreOutput(dctx), + kj::arrayPtr(buffer, sizeof(buffer) - availableOut)); + } + KJ_CASE_ONEOF(cctx, BrotliEncoderState*) { + BROTLI_BOOL result = BrotliEncoderCompressStream(cctx, flush, &availableIn, &nextIn, + &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result == BROTLI_TRUE, "brotli compression failed"); + + return kj::tuple(BrotliEncoderHasMoreOutput(cctx), + kj::arrayPtr(buffer, sizeof(buffer) - availableOut)); + } + } + KJ_UNREACHABLE; +} + +} // namespace _ (private) + +// ======================================================================================= + +BrotliInputStream::BrotliInputStream(InputStream& inner, kj::Maybe windowBitsParam) + : inner(inner), windowBits(windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS)), + nextIn(nullptr), availableIn(0) { + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); +} + +BrotliInputStream::~BrotliInputStream() noexcept(false) { + BrotliDecoderDestroyInstance(ctx); +} + +size_t BrotliInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return size_t(0); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +size_t BrotliInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + // Ask for more input unless there is pending output + if (availableIn == 0 && !BrotliDecoderHasMoreOutput(ctx)) { + size_t amount = inner.tryRead(buffer, 1, sizeof(buffer)); + if (amount == 0) { + KJ_REQUIRE(atValidEndpoint, "brotli compressed stream ended prematurely"); + return alreadyRead; + } else { + nextIn = buffer; + availableIn = amount; + } + } + + byte* nextOut = out; + size_t availableOut = maxBytes; + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, + "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream( + ctx, &availableIn, &nextIn, &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result != BROTLI_DECODER_RESULT_ERROR, "brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(ctx))); + + atValidEndpoint = result == BROTLI_DECODER_RESULT_SUCCESS; + if (atValidEndpoint && availableIn > 0) { + // There's more data available. Assume start of new content. + // Not sure if we actually want this, but there is limited potential for breakage as arbitrary + // trailing data should still be rejected. Unfortunately this is kind of clunky as brotli does + // not support resetting an instance. + BrotliDecoderDestroyInstance(ctx); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); + firstInput = true; + } + + size_t n = maxBytes - availableOut; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } +} + +BrotliOutputStream::BrotliOutputStream(OutputStream& inner, int compressionLevel, int windowBits) + : inner(inner), ctx(compressionLevel, windowBits) {} + +BrotliOutputStream::BrotliOutputStream(OutputStream& inner, decltype(DECOMPRESS), int windowBits) + : inner(inner), ctx(nullptr, windowBits) {} + +BrotliOutputStream::~BrotliOutputStream() noexcept(false) { + pump(BROTLI_OPERATION_FINISH); +} + +void BrotliOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + pump(BROTLI_OPERATION_PROCESS); +} + +void BrotliOutputStream::pump(BrotliEncoderOperation flush) { + bool ok; + do { + auto result = ctx.pumpOnce(flush); + ok = get<0>(result); + auto chunk = get<1>(result); + if (chunk.size() > 0) { + inner.write(chunk.begin(), chunk.size()); + } + } while (ok); +} + +// ======================================================================================= + +BrotliAsyncInputStream::BrotliAsyncInputStream(AsyncInputStream& inner, + kj::Maybe windowBitsParam) + : inner(inner), windowBits(windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS)), + nextIn(nullptr), availableIn(0) { + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); +} + +BrotliAsyncInputStream::~BrotliAsyncInputStream() noexcept(false) { + BrotliDecoderDestroyInstance(ctx); +} + +Promise BrotliAsyncInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return constPromise(); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +Promise BrotliAsyncInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + // Ask for more input unless there is pending output + if (availableIn == 0 && !BrotliDecoderHasMoreOutput(ctx)) { + return inner.tryRead(buffer, 1, sizeof(buffer)) + .then([this,out,minBytes,maxBytes,alreadyRead](size_t amount) -> Promise { + if (amount == 0) { + if (!atValidEndpoint) { + return KJ_EXCEPTION(DISCONNECTED, "brotli compressed stream ended prematurely"); + } + return alreadyRead; + } else { + nextIn = buffer; + availableIn = amount; + return readImpl(out, minBytes, maxBytes, alreadyRead); + } + }); + } + + byte* nextOut = out; + size_t availableOut = maxBytes; + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, + "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream( + ctx, &availableIn, &nextIn, &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result != BROTLI_DECODER_RESULT_ERROR, "brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(ctx))); + + atValidEndpoint = result == BROTLI_DECODER_RESULT_SUCCESS; + if (atValidEndpoint && availableIn > 0) { + // There's more data available. Assume start of new content. + // Not sure if we actually want this, but there is limited potential for breakage as arbitrary + // trailing data should still be rejected. Unfortunately this is kind of clunky as brotli does + // not support resetting an instance. + BrotliDecoderDestroyInstance(ctx); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); + firstInput = true; + } + + size_t n = maxBytes - availableOut; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } +} + +// ======================================================================================= + +BrotliAsyncOutputStream::BrotliAsyncOutputStream(AsyncOutputStream& inner, int compressionLevel, + int windowBits) + : inner(inner), ctx(compressionLevel, windowBits) {} + +BrotliAsyncOutputStream::BrotliAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS), + int windowBits) + : inner(inner), ctx(nullptr, windowBits) {} + +Promise BrotliAsyncOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + return pump(BROTLI_OPERATION_PROCESS); +} + +Promise BrotliAsyncOutputStream::write(ArrayPtr> pieces) { + if (pieces.size() == 0) return kj::READY_NOW; + return write(pieces[0].begin(), pieces[0].size()) + .then([this,pieces]() { + return write(pieces.slice(1, pieces.size())); + }); +} + +kj::Promise BrotliAsyncOutputStream::pump(BrotliEncoderOperation flush) { + auto result = ctx.pumpOnce(flush); + auto ok = get<0>(result); + auto chunk = get<1>(result); + + if (chunk.size() == 0) { + if (ok) { + return pump(flush); + } else { + return kj::READY_NOW; + } + } else { + auto promise = inner.write(chunk.begin(), chunk.size()); + if (ok) { + promise = promise.then([this, flush]() { return pump(flush); }); + } + return promise; + } +} + +} // namespace kj + +#endif // KJ_HAS_BROTLI diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.h new file mode 100644 index 00000000000..3fd2181b5c4 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/brotli.h @@ -0,0 +1,190 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include +#include +#include + +KJ_BEGIN_HEADER + +namespace kj { + +// level 5 should offer a good default tradeoff based on concerns about being slower than gzip at +// e.g. level 6 and about compressing worse than gzip at lower levels. Note that +// BROTLI_DEFAULT_QUALITY is set to the maximum level of 11 – way too slow for on-the-fly +// compression. +constexpr size_t KJ_BROTLI_DEFAULT_QUALITY = 5; + +namespace _ { // private +// Use a window size of (1 << 19) = 512K by default. Higher values improve compression on longer +// streams but increase memory usage. +constexpr size_t KJ_BROTLI_DEFAULT_WBITS = 19; + +// Maximum window size for streams to be decompressed, streams with larger windows are rejected. +// This is currently set to the maximum window size of 16MB, so all RFC 7932-compliant brotli +// streams will be accepted. For applications where memory usage is a concern, using +// BROTLI_DEFAULT_WINDOW (equivalent to 4MB window) is recommended instead as larger window sizes +// are rarely useful in a web context. +constexpr size_t KJ_BROTLI_MAX_DEC_WBITS = BROTLI_MAX_WINDOW_BITS; + +// Use an output buffer size of 8K, larger sizes did not seem to significantly improve performance, +// perhaps due to brotli's internal output buffer. +constexpr size_t KJ_BROTLI_BUF_SIZE = 8192; + +class BrotliOutputContext final { +public: + BrotliOutputContext(kj::Maybe compressionLevel, kj::Maybe windowBits = nullptr); + ~BrotliOutputContext() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliOutputContext); + + void setInput(const void* in, size_t size); + kj::Tuple> pumpOnce(BrotliEncoderOperation flush); + // Flush the stream. Parameter is ignored for decoding as brotli only uses an operation parameter + // during encoding. + +private: + int windowBits; + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + kj::OneOf ctx; + byte buffer[_::KJ_BROTLI_BUF_SIZE]; +}; + +} // namespace _ (private) + +class BrotliInputStream final: public InputStream { +public: + BrotliInputStream(InputStream& inner, kj::Maybe windowBits = nullptr); + ~BrotliInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliInputStream); + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + InputStream& inner; + BrotliDecoderState* ctx; + int windowBits; + bool atValidEndpoint = false; + + byte buffer[_::KJ_BROTLI_BUF_SIZE]; + + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class BrotliOutputStream final: public OutputStream { +public: + enum { DECOMPRESS }; + + // Order of arguments is not ideal, but allows us to specify the window size if needed while + // remaining compatible with the gzip API. + BrotliOutputStream(OutputStream& inner, int compressionLevel = KJ_BROTLI_DEFAULT_QUALITY, + int windowBits = _::KJ_BROTLI_DEFAULT_WBITS); + BrotliOutputStream(OutputStream& inner, decltype(DECOMPRESS), + int windowBits = _::KJ_BROTLI_MAX_DEC_WBITS); + ~BrotliOutputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliOutputStream); + + void write(const void* buffer, size_t size) override; + using OutputStream::write; + + inline void flush() { + // brotli decoder does not use this parameter, but automatically flushes as much as it can. + pump(BROTLI_OPERATION_FLUSH); + } + +private: + OutputStream& inner; + _::BrotliOutputContext ctx; + + void pump(BrotliEncoderOperation flush); +}; + +class BrotliAsyncInputStream final: public AsyncInputStream { +public: + BrotliAsyncInputStream(AsyncInputStream& inner, kj::Maybe windowBits = nullptr); + ~BrotliAsyncInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliAsyncInputStream); + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + AsyncInputStream& inner; + BrotliDecoderState* ctx; + int windowBits; + bool atValidEndpoint = false; + + byte buffer[_::KJ_BROTLI_BUF_SIZE]; + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + Promise readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class BrotliAsyncOutputStream final: public AsyncOutputStream { +public: + enum { DECOMPRESS }; + + BrotliAsyncOutputStream(AsyncOutputStream& inner, + int compressionLevel = KJ_BROTLI_DEFAULT_QUALITY, + int windowBits = _::KJ_BROTLI_DEFAULT_WBITS); + BrotliAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS), + int windowBits = _::KJ_BROTLI_MAX_DEC_WBITS); + KJ_DISALLOW_COPY_AND_MOVE(BrotliAsyncOutputStream); + + Promise write(const void* buffer, size_t size) override; + Promise write(ArrayPtr> pieces) override; + + Promise whenWriteDisconnected() override { return inner.whenWriteDisconnected(); } + + inline Promise flush() { + // brotli decoder does not use this parameter, but automatically flushes as much as it can. + return pump(BROTLI_OPERATION_FLUSH); + } + // Call if you need to flush a stream at an arbitrary data point. + + Promise end() { + return pump(BROTLI_OPERATION_FINISH); + } + // Must call to flush and finish the stream, since some data may be buffered. + // + // TODO(cleanup): This should be a virtual method on AsyncOutputStream. + +private: + AsyncOutputStream& inner; + _::BrotliOutputContext ctx; + + kj::Promise pump(BrotliEncoderOperation flush); +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gtest.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gtest.h index 0d7d361de1c..4db0535c351 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gtest.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gtest.h @@ -29,9 +29,11 @@ // - Test fixtures are not supported. Allocate your "test fixture" on the stack instead. Do setup // in the constructor, teardown in the destructor. -#include "../test.h" +#include #include // work-around macro conflict with `ERROR` +KJ_BEGIN_HEADER + namespace kj { namespace _ { // private @@ -118,3 +120,5 @@ class AddFailureAdapter { #define TEST(x, y) KJ_TEST("legacy test: " #x "/" #y) } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.c++ index 60d5a8f09d2..a36cde774fa 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.c++ @@ -103,6 +103,12 @@ size_t GzipInputStream::readImpl( byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { if (ctx.avail_in == 0) { size_t amount = inner.tryRead(buffer, 1, sizeof(buffer)); + // Note: This check would reject valid streams with a high compression ratio if zlib were to + // read in the entire input data, getting more decompressed data than fits in the out buffer + // and subsequently fill the output buffer and internally store some pending data. It turns + // out that zlib does not maintain pending output during decompression and this is not + // possible, but this may be a concern when implementing support for other algorithms as e.g. + // brotli's reference implementation maintains a decompression output buffer. if (amount == 0) { if (!atValidEndpoint) { KJ_FAIL_REQUIRE("gzip compressed stream ended prematurely"); @@ -114,7 +120,7 @@ size_t GzipInputStream::readImpl( } } - ctx.next_out = reinterpret_cast(out); + ctx.next_out = out; ctx.avail_out = maxBytes; auto inflateResult = inflate(&ctx, Z_NO_FLUSH); @@ -182,7 +188,7 @@ GzipAsyncInputStream::~GzipAsyncInputStream() noexcept(false) { } Promise GzipAsyncInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { - if (maxBytes == 0) return size_t(0); + if (maxBytes == 0) return constPromise(); return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); } @@ -205,7 +211,7 @@ Promise GzipAsyncInputStream::readImpl( }); } - ctx.next_out = reinterpret_cast(out); + ctx.next_out = out; ctx.avail_out = maxBytes; auto inflateResult = inflate(&ctx, Z_NO_FLUSH); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.h index 5045e11757d..37b4961fed5 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/gzip.h @@ -25,15 +25,19 @@ #include #include +KJ_BEGIN_HEADER + namespace kj { namespace _ { // private +constexpr size_t KJ_GZ_BUF_SIZE = 4096; + class GzipOutputContext final { public: GzipOutputContext(kj::Maybe compressionLevel); ~GzipOutputContext() noexcept(false); - KJ_DISALLOW_COPY(GzipOutputContext); + KJ_DISALLOW_COPY_AND_MOVE(GzipOutputContext); void setInput(const void* in, size_t size); kj::Tuple> pumpOnce(int flush); @@ -41,7 +45,7 @@ class GzipOutputContext final { private: bool compressing; z_stream ctx = {}; - byte buffer[4096]; + byte buffer[_::KJ_GZ_BUF_SIZE]; [[noreturn]] void fail(int result); }; @@ -52,7 +56,7 @@ class GzipInputStream final: public InputStream { public: GzipInputStream(InputStream& inner); ~GzipInputStream() noexcept(false); - KJ_DISALLOW_COPY(GzipInputStream); + KJ_DISALLOW_COPY_AND_MOVE(GzipInputStream); size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -61,7 +65,7 @@ class GzipInputStream final: public InputStream { z_stream ctx = {}; bool atValidEndpoint = false; - byte buffer[4096]; + byte buffer[_::KJ_GZ_BUF_SIZE]; size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); }; @@ -73,7 +77,7 @@ class GzipOutputStream final: public OutputStream { GzipOutputStream(OutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION); GzipOutputStream(OutputStream& inner, decltype(DECOMPRESS)); ~GzipOutputStream() noexcept(false); - KJ_DISALLOW_COPY(GzipOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(GzipOutputStream); void write(const void* buffer, size_t size) override; using OutputStream::write; @@ -93,7 +97,7 @@ class GzipAsyncInputStream final: public AsyncInputStream { public: GzipAsyncInputStream(AsyncInputStream& inner); ~GzipAsyncInputStream() noexcept(false); - KJ_DISALLOW_COPY(GzipAsyncInputStream); + KJ_DISALLOW_COPY_AND_MOVE(GzipAsyncInputStream); Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -102,7 +106,7 @@ class GzipAsyncInputStream final: public AsyncInputStream { z_stream ctx = {}; bool atValidEndpoint = false; - byte buffer[4096]; + byte buffer[_::KJ_GZ_BUF_SIZE]; Promise readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); }; @@ -113,7 +117,7 @@ class GzipAsyncOutputStream final: public AsyncOutputStream { GzipAsyncOutputStream(AsyncOutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION); GzipAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS)); - KJ_DISALLOW_COPY(GzipAsyncOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(GzipAsyncOutputStream); Promise write(const void* buffer, size_t size) override; Promise write(ArrayPtr> pieces) override; @@ -140,3 +144,5 @@ class GzipAsyncOutputStream final: public AsyncOutputStream { }; } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http-test.c++ index 87395534e2d..f10ff8d1564 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http-test.c++ @@ -25,6 +25,7 @@ #include #include #include +#include #include #if KJ_HTTP_TEST_USE_OS_PIPE @@ -58,8 +59,15 @@ namespace { KJ_TEST("HttpMethod parse / stringify") { #define TRY(name) \ KJ_EXPECT(kj::str(HttpMethod::name) == #name); \ - KJ_IF_MAYBE(parsed, tryParseHttpMethod(#name)) { \ - KJ_EXPECT(*parsed == HttpMethod::name); \ + KJ_IF_MAYBE(parsed, tryParseHttpMethodAllowingConnect(#name)) { \ + KJ_SWITCH_ONEOF(*parsed) { \ + KJ_CASE_ONEOF(method, HttpMethod) { \ + KJ_EXPECT(method == HttpMethod::name); \ + } \ + KJ_CASE_ONEOF(method, HttpConnectMethod) { \ + KJ_FAIL_EXPECT("http method parsed as CONNECT", #name); \ + } \ + } \ } else { \ KJ_FAIL_EXPECT("couldn't parse \"" #name "\" as HttpMethod"); \ } @@ -73,6 +81,10 @@ KJ_TEST("HttpMethod parse / stringify") { KJ_EXPECT(tryParseHttpMethod("GE") == nullptr); KJ_EXPECT(tryParseHttpMethod("GET ") == nullptr); KJ_EXPECT(tryParseHttpMethod("get") == nullptr); + + KJ_EXPECT(KJ_ASSERT_NONNULL(tryParseHttpMethodAllowingConnect("CONNECT")) + .is()); + KJ_EXPECT(tryParseHttpMethod("connect") == nullptr); } KJ_TEST("HttpHeaderTable") { @@ -296,6 +308,48 @@ KJ_TEST("HttpHeaders parse invalid") { } } +KJ_TEST("HttpHeaders require valid HttpHeaderTable") { + const auto ERROR_MESSAGE = + "HttpHeaders object was constructed from HttpHeaderTable " + "that wasn't fully built yet at the time of construction"_kj; + + { + // A tabula rasa is valid. + HttpHeaderTable table; + KJ_REQUIRE(table.isReady()); + + HttpHeaders headers(table); + } + + { + // A future table is not valid. + HttpHeaderTable::Builder builder; + + auto& futureTable = builder.getFutureTable(); + KJ_REQUIRE(!futureTable.isReady()); + + auto makeHeadersThenBuild = [&]() { + HttpHeaders headers(futureTable); + auto table = builder.build(); + }; + KJ_EXPECT_THROW_MESSAGE(ERROR_MESSAGE, makeHeadersThenBuild()); + } + + { + // A well built table is valid. + HttpHeaderTable::Builder builder; + + auto& futureTable = builder.getFutureTable(); + KJ_REQUIRE(!futureTable.isReady()); + + auto ownedTable = builder.build(); + KJ_REQUIRE(futureTable.isReady()); + KJ_REQUIRE(ownedTable->isReady()); + + HttpHeaders headers(futureTable); + } +} + KJ_TEST("HttpHeaders validation") { auto table = HttpHeaderTable::Builder().build(); HttpHeaders headers(*table); @@ -482,7 +536,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -493,7 +547,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { } return expectRead(in, expected.slice(amount)); - })); + }); } kj::Promise expectRead(kj::AsyncInputStream& in, kj::ArrayPtr expected) { @@ -502,7 +556,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::ArrayPtr auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -513,7 +567,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::ArrayPtr } return expectRead(in, expected.slice(amount, expected.size())); - })); + }); } kj::Promise expectEnd(kj::AsyncInputStream& in) { @@ -1762,6 +1816,39 @@ KJ_TEST("WebSocket fragmented") { clientTask.wait(waitScope); } +#if KJ_HAS_ZLIB +KJ_TEST("WebSocket compressed fragment") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, CompressionParameters{ + .outboundNoContextTakeover = false, + .inboundNoContextTakeover = false, + .outboundMaxWindowBits=15, + .inboundMaxWindowBits=15, + }); + + // The message is "Hello", sent in two fragments, see the fragmented example at the bottom of: + // https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + byte COMPRESSED_DATA[] = { + 0x41, 0x03, 0xf2, 0x48, 0xcd, + + 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00 + }; + + auto clientTask = client->write(COMPRESSED_DATA, sizeof(COMPRESSED_DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + + clientTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + class FakeEntropySource final: public EntropySource { public: void generate(kj::ArrayPtr buffer) override { @@ -1800,6 +1887,148 @@ KJ_TEST("WebSocket masked") { serverTask.wait(waitScope); } +class WebSocketErrorCatcher : public WebSocketErrorHandler { +public: + kj::Vector errors; + + kj::Exception handleWebSocketProtocolError(kj::WebSocket::ProtocolError protocolError) { + errors.add(kj::mv(protocolError)); + return KJ_EXCEPTION(FAILED, protocolError.description); + } +}; + +KJ_TEST("WebSocket unexpected RSV bits") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', + + 0xF0, 0x05, 'w', 'o', 'r', 'l', 'd' // all RSV bits set, plus FIN + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket unexpected continuation frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x80, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Continuation frame with no start frame, plus FIN + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket missing continuation frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Start frame + 0x01, 0x06, 'w', 'o', 'r', 'l', 'd', '!', // Another start frame + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket fragmented control frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x09, 0x04, 'd', 'a', 't', 'a' // Fragmented ping frame + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket unknown opcode") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x85, 0x04, 'd', 'a', 't', 'a' // 5 is a reserved opcode + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + KJ_TEST("WebSocket unsolicited pong") { KJ_HTTP_TEST_SETUP_IO; auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; @@ -2038,7 +2267,7 @@ KJ_TEST("WebSocket pump byte counting") { // The pump completes successfully, forwarding the disconnect. pumpTask.wait(waitScope); - // The eventual receiver gets a disconnect execption. + // The eventual receiver gets a disconnect exception. // (Note: We don't use KJ_EXPECT_THROW here because under -fno-exceptions it forks and we lose // state.) receiveTask.then([](auto) { @@ -2106,10 +2335,62 @@ KJ_TEST("WebSocket pump disconnect on receive") { // The pump completes successfully, forwarding the disconnect. pumpTask.wait(waitScope); - // The eventual receiver gets a disconnect execption. + // The eventual receiver gets a disconnect exception. KJ_EXPECT_THROW(DISCONNECTED, receiveTask.wait(waitScope)); } +KJ_TEST("WebSocket abort propagates through pipe") { + // Pumping one end of a WebSocket pipe into another WebSocket which later becomes aborted will + // cancel the pump promise with a DISCONNECTED exception. + + KJ_HTTP_TEST_SETUP_IO; + auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE; + + auto server = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); + auto client = newWebSocket(kj::mv(pipe1.ends[0]), nullptr); + + auto wsPipe = newWebSocketPipe(); + + auto downstreamPump = wsPipe.ends[0]->pumpTo(*server); + KJ_EXPECT(!downstreamPump.poll(waitScope)); + + client->abort(); + + KJ_EXPECT(downstreamPump.poll(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE(DISCONNECTED, downstreamPump.wait(waitScope)); +} + +KJ_TEST("WebSocket maximum message size") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe =KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + FakeEntropySource maskGenerator; + auto client = newWebSocket(kj::mv(pipe.ends[0]), maskGenerator); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + size_t maxSize = 100; + auto biggestAllowedString = kj::strArray(kj::repeat(kj::StringPtr("A"), maxSize), ""); + auto tooBigString = kj::strArray(kj::repeat(kj::StringPtr("B"), maxSize + 1), ""); + + auto clientTask = client->send(biggestAllowedString) + .then([&]() { return client->send(tooBigString); }) + .then([&]() { return client->close(1234, "done"); }); + + { + auto message = server->receive(maxSize).wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().size() == maxSize); + } + + { + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("too large", + server->receive(maxSize).ignoreResult().wait(waitScope)); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1009); + } +} + class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler { public: TestWebSocketService(HttpHeaderTable& headerTable, HttpHeaderId hMyHeader) @@ -2183,6 +2464,40 @@ const char WEBSOCKET_RESPONSE_HANDSHAKE[] = "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" "My-Header: respond-foo\r\n" "\r\n"; +#if KJ_HAS_ZLIB +const char WEBSOCKET_COMPRESSION_HANDSHAKE[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; " + "server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; " + "server_no_context_takeover\r\n" + "\r\n"; +#endif // KJ_HAS_ZLIB const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] = "HTTP/1.1 404 Not Found\r\n" "Content-Length: 0\r\n" @@ -2199,6 +2514,33 @@ const byte WEBSOCKET_SEND_CLOSE[] = const byte WEBSOCKET_REPLY_CLOSE[] = { 0x88, 0x11, 0x12, 0x35, 'c','l','o','s','e','-','r','e','p','l','y',':','q','u','x' }; +#if KJ_HAS_ZLIB +const byte WEBSOCKET_FIRST_COMPRESSED_MESSAGE[] = + { 0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.2 +const byte WEBSOCKET_SEND_COMPRESSED_MESSAGE[] = + { 0xc1, 0x87, 12, 34, 56, 78, 0xf2^12, 0x48^34, 0xcd^56, 0xc9^78, 0xc9^12, 0x07^34, 0x00^56 }; +const byte WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX[] = + { 0xc1, 0x85, 12, 34, 56, 78, 0xf2^12, 0x00^34, 0x11^56, 0x00^78, 0x00^12}; +// See same compression example, but where `client_no_context_takeover` is used (saves 2 bytes). +const byte WEBSOCKET_DEFLATE_NO_COMPRESSION_MESSAGE[] = + { 0xc1, 0x0b, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.3 +// This uses a DEFLATE block with no compression. +const byte WEBSOCKET_BFINAL_SET_MESSAGE[] = + { 0xc1, 0x08, 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.4 +// This uses a DEFLATE block with BFINAL set to 1. +const byte WEBSOCKET_TWO_DEFLATE_BLOCKS_MESSAGE[] = + { 0xc1, 0x0d, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xca, 0xc9, 0xc9, 0x07, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.5 +// This uses two DEFLATE blocks in a single message. +const byte WEBSOCKET_EMPTY_COMPRESSED_MESSAGE[] = + { 0xc1, 0x01, 0x00 }; +const byte WEBSOCKET_EMPTY_SEND_COMPRESSED_MESSAGE[] = + { 0xc1, 0x81, 12, 34, 56, 78, 0x00^12 }; +#endif // KJ_HAS_ZLIB + template kj::ArrayPtr asBytes(const char (&chars)[s]) { return kj::ArrayPtr(chars, s - 1).asBytes(); @@ -2238,539 +2580,1245 @@ void testWebSocketClient(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, } } -inline kj::Promise writeA(kj::AsyncOutputStream& out, kj::ArrayPtr data) { - return out.write(data.begin(), data.size()); -} - -KJ_TEST("HttpClient WebSocket handshake") { - KJ_HTTP_TEST_SETUP_IO; - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - - auto serverTask = expectRead(*pipe.ends[1], request) - .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) - .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) - .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) - .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) - .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) - .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) - .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); +#if KJ_HAS_ZLIB +void testWebSocketTwoMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // In this test, the server will always use `server_no_context_takeover` (since we can just reuse + // the message). However, we will modify the client's compressor in different ways to see how the + // compressed message changes. - HttpHeaderTable::Builder tableBuilder; - HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); - auto headerTable = tableBuilder.build(); + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); - auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); - testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); - serverTask.wait(waitScope); + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } } +#endif // KJ_HAS_ZLIB -KJ_TEST("HttpClient WebSocket error") { - KJ_HTTP_TEST_SETUP_IO; - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - - auto serverTask = expectRead(*pipe.ends[1], request) - .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) - .then([&]() { return expectRead(*pipe.ends[1], request); }) - .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) - .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - - HttpHeaderTable::Builder tableBuilder; - HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); - auto headerTable = tableBuilder.build(); - - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; +#if KJ_HAS_ZLIB +void testWebSocketEmptyMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // Confirm that we can send empty messages when compression is enabled. - auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); - kj::HttpHeaders headers(*headerTable); - headers.set(hMyHeader, "foo"); + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); { - auto response = client->openWebSocket("/websocket", headers).wait(waitScope); + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); - KJ_EXPECT(response.statusCode == 404); - KJ_EXPECT(response.statusText == "Not Found", response.statusText); - KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); - KJ_ASSERT(response.webSocketOrBody.is>()); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == ""); } + ws->send(kj::StringPtr("")).wait(waitScope); { - auto response = client->openWebSocket("/websocket", headers).wait(waitScope); - - KJ_EXPECT(response.statusCode == 404); - KJ_EXPECT(response.statusText == "Not Found", response.statusText); - KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); - KJ_ASSERT(response.webSocketOrBody.is>()); + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); } + ws->send(kj::StringPtr("Hello")).wait(waitScope); - serverTask.wait(waitScope); + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } } +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +void testWebSocketOptimizePumpProxy(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // Suppose we are proxying a websocket conversation between a client and a server. + // This looks something like: CLIENT <--> (proxyServer <==PUMP==> proxyClient) <--> SERVER + // + // We want to enable optimizedPumping from the proxy's server (which communicates with the client), + // to the proxy's client (which communicates with the origin server). + // + // For this to work, proxyServer's inbound settings must map to proxyClient's outbound settings + // (and vice versa). In this case, `ws` is `proxyClient`, so we want to take `ws`'s compression + // configuration and pass it to `proxyServer` in a way that would allow for optimizedPumping. -KJ_TEST("HttpServer WebSocket handshake") { - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); - HttpHeaderTable::Builder tableBuilder; - HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); - auto headerTable = tableBuilder.build(); - TestWebSocketService service(*headerTable, hMyHeader); - HttpServer server(timer, *headerTable, service); + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto maybeExt = ws->getPreferredExtensions(kj::WebSocket::ExtensionsContext::REQUEST); + // Should be nullptr since we are asking `ws` (a client) to give us extensions that we can give to + // another client. Since clients cannot `optimizedPumpTo` each other, we must get null. + KJ_ASSERT(maybeExt == nullptr); - auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); - expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + maybeExt = ws->getPreferredExtensions(kj::WebSocket::ExtensionsContext::RESPONSE); + kj::StringPtr extStr = KJ_ASSERT_NONNULL(maybeExt); + KJ_ASSERT(extStr == "permessage-deflate; server_no_context_takeover"); + // We got back the string the client sent! + // We could then pass this string as a header to `acceptWebSocket` and ensure the `proxyServer`s + // inbound settings match the `proxyClient`s outbound settings. - expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); - writeA(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE).wait(waitScope); - expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); - writeA(*pipe.ends[1], WEBSOCKET_SEND_CLOSE).wait(waitScope); - expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(waitScope); + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} +#endif // KJ_HAS_ZLIB +#if KJ_HAS_ZLIB +void testWebSocketFourMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // In this test, the server will always use `server_no_context_takeover` (since we can just reuse + // the message). We will receive three messages. - listenTask.wait(waitScope); + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + for (size_t i = 0; i < 4; i++) { + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + } + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } } +#endif // KJ_HAS_ZLIB -KJ_TEST("HttpServer WebSocket handshake error") { +inline kj::Promise writeA(kj::AsyncOutputStream& out, kj::ArrayPtr data) { + return out.write(data.begin(), data.size()); +} + +KJ_TEST("HttpClient WebSocket handshake") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + HttpHeaderTable::Builder tableBuilder; HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); auto headerTable = tableBuilder.build(); - TestWebSocketService service(*headerTable, hMyHeader); - HttpServer server(timer, *headerTable, service); - - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE); - writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); - expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; - // Can send more requests! - writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); - expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); - pipe.ends[1]->shutdownWrite(); + testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); - listenTask.wait(waitScope); + serverTask.wait(waitScope); } -// ----------------------------------------------------------------------------- +KJ_TEST("WebSocket Compression String Parsing (splitNext)") { + // Test `splitNext()`. + // We want to assert that: + // If a delimiter is found: + // - `input` is updated to point to the rest of the string after the delimiter. + // - The text before the delimiter is returned. + // If no delimiter is found: + // - `input` is updated to an empty string. + // - The text that had been in `input` is returned. -KJ_TEST("HttpServer request timeout") { - auto PIPELINE_TESTS = pipelineTestCases(); + const auto s = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover"_kj; - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + const auto expectedPartOne = "permessage-deflate"_kj; + const auto expectedRemainingOne = "client_max_window_bits=10;server_no_context_takeover"_kj; - HttpHeaderTable table; - TestHttpService service(PIPELINE_TESTS, table); - HttpServerSettings settings; - settings.headerTimeout = 1 * kj::MILLISECONDS; - HttpServer server(timer, table, service, settings); + auto cursor = s.asArray(); + auto actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartOne); - // Shouldn't hang! Should time out. - auto promise = server.listenHttp(kj::mv(pipe.ends[0])); - KJ_EXPECT(!promise.poll(waitScope)); - timer.advanceTo(timer.now() + settings.headerTimeout / 2); - KJ_EXPECT(!promise.poll(waitScope)); - timer.advanceTo(timer.now() + settings.headerTimeout); - promise.wait(waitScope); + _::stripLeadingAndTrailingSpace(cursor); + KJ_ASSERT(cursor == expectedRemainingOne.asArray()); - // Closes the connection without sending anything. - KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); + const auto expectedPartTwo = "client_max_window_bits=10"_kj; + const auto expectedRemainingTwo = "server_no_context_takeover"_kj; + + actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartTwo); + KJ_ASSERT(cursor == expectedRemainingTwo); + + const auto expectedPartThree = "server_no_context_takeover"_kj; + const auto expectedRemainingThree = ""_kj; + actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartThree); + KJ_ASSERT(cursor == expectedRemainingThree); } -KJ_TEST("HttpServer pipeline timeout") { - auto PIPELINE_TESTS = pipelineTestCases(); +KJ_TEST("WebSocket Compression String Parsing (splitParts)") { + // Test `splitParts()`. + // We want to assert that we: + // 1. Correctly split by the delimiter. + // 2. Strip whitespace before/after the extracted part. + const auto permitted = "permessage-deflate"_kj; + + const auto s = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover, " + " permessage-deflate; ; ," // strips leading whitespace + "permessage-deflate"_kj; + + // These are the expected values. + const auto extOne = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover"_kj; + const auto extTwo = "permessage-deflate; ;"_kj; + const auto extThree = "permessage-deflate"_kj; + + auto actualExtensions = kj::_::splitParts(s, ','); + KJ_ASSERT(actualExtensions.size() == 3); + KJ_ASSERT(actualExtensions[0] == extOne); + KJ_ASSERT(actualExtensions[1] == extTwo); + KJ_ASSERT(actualExtensions[2] == extThree); + // Splitting by ',' was fine, now let's try splitting the parameters (split by ';'). + + const auto paramOne = "client_max_window_bits=10"_kj; + const auto paramTwo = "server_no_context_takeover"_kj; + + auto actualParamsFirstExt = kj::_::splitParts(actualExtensions[0], ';'); + KJ_ASSERT(actualParamsFirstExt.size() == 3); + KJ_ASSERT(actualParamsFirstExt[0] == permitted); + KJ_ASSERT(actualParamsFirstExt[1] == paramOne); + KJ_ASSERT(actualParamsFirstExt[2] == paramTwo); + + auto actualParamsSecondExt = kj::_::splitParts(actualExtensions[1], ';'); + KJ_ASSERT(actualParamsSecondExt.size() == 2); + KJ_ASSERT(actualParamsSecondExt[0] == permitted); + KJ_ASSERT(actualParamsSecondExt[1] == ""_kj); // Note that the whitespace was stripped. + + auto actualParamsThirdExt = kj::_::splitParts(actualExtensions[2], ';'); + // No parameters supplied in the third offer. We expect to only see the extension name. + KJ_ASSERT(actualParamsThirdExt.size() == 1); + KJ_ASSERT(actualParamsThirdExt[0] == permitted); +} - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; +KJ_TEST("WebSocket Compression String Parsing (toKeysAndVals)") { + // If an "=" is found, everything before the "=" goes into the `Key` and everything after goes + // into the `Value`. Otherwise, everything goes into the `Key` and the `Value` remains null. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(keysMaybeValues.size() == 3); + + auto firstKey = "client_no_context_takeover"_kj; + KJ_ASSERT(keysMaybeValues[0].key == firstKey.asArray()); + KJ_ASSERT(keysMaybeValues[0].val == nullptr); + + auto secondKey = "client_max_window_bits"_kj; + KJ_ASSERT(keysMaybeValues[1].key == secondKey.asArray()); + KJ_ASSERT(keysMaybeValues[1].val == nullptr); + + auto thirdKey = "server_max_window_bits"_kj; + auto thirdVal = "10"_kj; + KJ_ASSERT(keysMaybeValues[2].key == thirdKey.asArray()); + KJ_ASSERT(keysMaybeValues[2].val == thirdVal.asArray()); + + const auto weirdParameters = "= 14 ; client_max_window_bits= ; server_max_window_bits =hello"_kj; + // This is weird because: + // 1. Parameter 1 has no key. + // 2. Parameter 2 has an "=" but no subsequent value. + // 3. Parameter 3 has an "=" with an invalid value. + // That said, we don't mind if the parameters are weird when calling this function. The point + // is to create KeyMaybeVal pairs and process them later. + + parts = _::splitParts(weirdParameters, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(keysMaybeValues.size() == 3); + + firstKey = ""_kj; + auto firstVal = "14"_kj; + KJ_ASSERT(keysMaybeValues[0].key == firstKey.asArray()); + KJ_ASSERT(keysMaybeValues[0].val == firstVal.asArray()); + + secondKey = "client_max_window_bits"_kj; + auto secondVal = ""_kj; + KJ_ASSERT(keysMaybeValues[1].key == secondKey.asArray()); + KJ_ASSERT(keysMaybeValues[1].val == secondVal.asArray()); + + thirdKey = "server_max_window_bits"_kj; + thirdVal = "hello"_kj; + KJ_ASSERT(keysMaybeValues[2].key == thirdKey.asArray()); + KJ_ASSERT(keysMaybeValues[2].val == thirdVal.asArray()); +} - HttpHeaderTable table; - TestHttpService service(PIPELINE_TESTS, table); - HttpServerSettings settings; - settings.pipelineTimeout = 1 * kj::MILLISECONDS; - HttpServer server(timer, table, service, settings); +KJ_TEST("WebSocket Compression String Parsing (populateUnverifiedConfig)") { + // First we'll cover cases where the `UnverifiedConfig` is successfully constructed, + // which indicates the offer was structured in a parseable way. Next, we'll cover cases where the + // offer is structured incorrectly. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + auto unverified = _::populateUnverifiedConfig(keysMaybeValues); + auto config = KJ_ASSERT_NONNULL(unverified); + KJ_ASSERT(config.clientNoContextTakeover == true); + KJ_ASSERT(config.serverNoContextTakeover == false); + + auto clientBits = KJ_ASSERT_NONNULL(config.clientMaxWindowBits); + KJ_ASSERT(clientBits == ""_kj); + auto serverBits = KJ_ASSERT_NONNULL(config.serverMaxWindowBits); + KJ_ASSERT(serverBits == "10"_kj); + // Valid config can be populated succesfully. + + const auto weirdButValidParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=this_should_be_a_number"_kj; + parts = _::splitParts(weirdButValidParameters, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + unverified = _::populateUnverifiedConfig(keysMaybeValues); + config = KJ_ASSERT_NONNULL(unverified); + KJ_ASSERT(config.clientNoContextTakeover == true); + KJ_ASSERT(config.serverNoContextTakeover == false); + + clientBits = KJ_ASSERT_NONNULL(config.clientMaxWindowBits); + KJ_ASSERT(clientBits == ""_kj); + serverBits = KJ_ASSERT_NONNULL(config.serverMaxWindowBits); + KJ_ASSERT(serverBits == "this_should_be_a_number"_kj); + // Note that while the value associated with `server_max_window_bits` is not a number, + // `populateUnverifiedConfig` succeeds because the parameter[=value] is generally structured + // correctly. + + // --- HANDLE INCORRECTLY STRUCTURED OFFERS --- + auto invalidKey = "somethingKey; client_max_window_bits;"_kj; + parts = _::splitParts(invalidKey, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to invalid key name + + auto invalidKeyTwo = "client_max_window_bitsJUNK; server_no_context_takeover"_kj; + parts = _::splitParts(invalidKeyTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to invalid key name (invalid characters after valid parameter name). + + auto repeatedKey = "client_no_context_takeover; client_no_context_takeover"_kj; + parts = _::splitParts(repeatedKey, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to repeated key name. + + auto unexpectedValue = "client_no_context_takeover="_kj; + parts = _::splitParts(unexpectedValue, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to value in `x_no_context_takeover` parameter (unexpected value). + + auto unexpectedValueTwo = "client_no_context_takeover= "_kj; + parts = _::splitParts(unexpectedValueTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to value in `x_no_context_takeover` parameter. + + auto emptyValue = "client_max_window_bits="_kj; + parts = _::splitParts(emptyValue, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to empty value in `x_max_window_bits` parameter. + // "Empty" in this case means an "=" was provided, but no subsequent value was provided. + + auto emptyValueTwo = "client_max_window_bits= "_kj; + parts = _::splitParts(emptyValueTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to empty value in `x_max_window_bits` parameter. + // "Empty" in this case means an "=" was provided, but no subsequent value was provided. +} - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); +KJ_TEST("WebSocket Compression String Parsing (validateCompressionConfig)") { + // We've tested `toKeysAndVals()` and `populateUnverifiedConfig()`, so we only need to test + // correctly structured offers/agreements here. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + auto maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + auto unverified = KJ_ASSERT_NONNULL(maybeUnverified); + auto maybeValid = _::validateCompressionConfig(kj::mv(unverified), false); // Validate as Server. + auto valid = KJ_ASSERT_NONNULL(maybeValid); + KJ_ASSERT(valid.inboundNoContextTakeover == true); + KJ_ASSERT(valid.outboundNoContextTakeover == false); + auto inboundBits = KJ_ASSERT_NONNULL(valid.inboundMaxWindowBits); + KJ_ASSERT(inboundBits == 15); // `client_max_window_bits` can be empty in an offer. + auto outboundBits = KJ_ASSERT_NONNULL(valid.outboundMaxWindowBits); + KJ_ASSERT(outboundBits == 10); + // Valid config successfully constructed. + + const auto correctStructureButInvalid = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=this_should_be_a_number"_kj; + parts = _::splitParts(correctStructureButInvalid, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + unverified = KJ_ASSERT_NONNULL(maybeUnverified); + maybeValid = _::validateCompressionConfig(kj::mv(unverified), false); // Validate as Server. + KJ_ASSERT(maybeValid == nullptr); + // The config "looks" correct, but the `server_max_window_bits` parameter has an invalid value. + + const auto invalidRange = "client_max_window_bits; server_max_window_bits=18;"_kj; + // `server_max_window_bits` is out of range, decline. + parts = _::splitParts(invalidRange, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidRangeTwo = "client_max_window_bits=4"_kj; + // `server_max_window_bits` is out of range, decline. + parts = _::splitParts(invalidRangeTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidRequest = "server_max_window_bits"_kj; + // `sever_max_window_bits` must have a value in a request AND a response. + parts = _::splitParts(invalidRequest, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidResponse = "client_max_window_bits"_kj; + // `client_max_window_bits` must have a value in a response. + parts = _::splitParts(invalidResponse, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), true); + KJ_ASSERT(maybeValid == nullptr); +} - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - expectRead(*pipe.ends[1], PIPELINE_TESTS[0].response.raw).wait(waitScope); +KJ_TEST("WebSocket Compression String Parsing (findValidExtensionOffers)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " // Valid offer. + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Another valid offer. + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-invalid; " // Invalid ext name. + "client_no_context_takeover, " + "permessage-deflate; " // Invalid parmeter. + "invalid_parameter; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Invalid parmeter value. + "server_max_window_bits=should_be_a_number, " + "permessage-deflate; " // Unexpected parmeter value. + "client_max_window_bits=true, " + "permessage-deflate; " // Missing expected parmeter value. + "server_max_window_bits, " + "permessage-deflate; " // Invalid parameter value (too high). + "client_max_window_bits=99, " + "permessage-deflate; " // Invalid parameter value (too low). + "client_max_window_bits=4, " + "permessage-deflate; " // Invalid parameter (repeated). + "client_max_window_bits; " + "client_max_window_bits, " + "permessage-deflate"_kj; // Valid offer (no parameters). + + auto validOffers = _::findValidExtensionOffers(extensions); + KJ_ASSERT(validOffers.size() == 3); + KJ_ASSERT(validOffers[0].outboundNoContextTakeover == true); + KJ_ASSERT(validOffers[0].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[0].outboundMaxWindowBits == 15); + KJ_ASSERT(validOffers[0].inboundMaxWindowBits == 10); + + KJ_ASSERT(validOffers[1].outboundNoContextTakeover == true); + KJ_ASSERT(validOffers[1].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[1].outboundMaxWindowBits == 15); + KJ_ASSERT(validOffers[1].inboundMaxWindowBits == nullptr); + + KJ_ASSERT(validOffers[2].outboundNoContextTakeover == false); + KJ_ASSERT(validOffers[2].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[2].outboundMaxWindowBits == nullptr); + KJ_ASSERT(validOffers[2].inboundMaxWindowBits == nullptr); +} - // Listen task should time out even though we didn't shutdown the socket. - KJ_EXPECT(!listenTask.poll(waitScope)); - timer.advanceTo(timer.now() + settings.pipelineTimeout / 2); - KJ_EXPECT(!listenTask.poll(waitScope)); - timer.advanceTo(timer.now() + settings.pipelineTimeout); - listenTask.wait(waitScope); +KJ_TEST("WebSocket Compression String Parsing (generateExtensionRequest)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits=10; " + "client_max_window_bits, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; + constexpr auto EXPECTED = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15; " + "server_max_window_bits=10, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15, " + "permessage-deflate"_kj; + auto validOffers = _::findValidExtensionOffers(extensions); + auto extensionRequest = _::generateExtensionRequest(validOffers); + KJ_ASSERT(extensionRequest == EXPECTED); +} - // In this case, no data is sent back. - KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); +KJ_TEST("WebSocket Compression String Parsing (tryParseExtensionOffers)") { + // Test that we can accept a valid offer from string of offers. + constexpr auto extensions = "permessage-invalid; " // Invalid ext name. + "client_no_context_takeover, " + "permessage-deflate; " // Invalid parmeter. + "invalid_parameter; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Invalid parmeter value. + "server_max_window_bits=should_be_a_number, " + "permessage-deflate; " // Unexpected parmeter value. + "client_max_window_bits=true, " + "permessage-deflate; " // Missing expected parmeter value. + "server_max_window_bits, " + "permessage-deflate; " // Invalid parameter value (too high). + "client_max_window_bits=99, " + "permessage-deflate; " // Invalid parameter value (too low). + "client_max_window_bits=4, " + "permessage-deflate; " // Invalid parameter (repeated). + "client_max_window_bits; " + "client_max_window_bits, " + "permessage-deflate; " // Valid offer. + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Another valid offer. + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; // Valid offer (no parameters). + + auto maybeAccepted = _::tryParseExtensionOffers(extensions); + auto accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == true); + KJ_ASSERT(accepted.outboundMaxWindowBits == 10); + KJ_ASSERT(accepted.inboundMaxWindowBits == 15); + + // Try the second valid offer from the big list above. + auto offerTwo = "permessage-deflate; client_no_context_takeover; client_max_window_bits"_kj; + maybeAccepted = _::tryParseExtensionOffers(offerTwo); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == true); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == 15); + + auto offerThree = "permessage-deflate"_kj; // The third valid offer. + maybeAccepted = _::tryParseExtensionOffers(offerThree); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + + auto invalid = "invalid"_kj; // Any of the invalid offers we saw above would return NULL. + maybeAccepted = _::tryParseExtensionOffers(invalid); + KJ_ASSERT(maybeAccepted == nullptr); } -class BrokenHttpService final: public HttpService { - // HttpService that doesn't send a response. -public: - BrokenHttpService() = default; - explicit BrokenHttpService(kj::Exception&& exception): exception(kj::mv(exception)) {} +KJ_TEST("WebSocket Compression String Parsing (tryParseAllExtensionOffers)") { + // We want to test the following: + // 1. We reject all if we don't find an offer we can accept. + // 2. We accept one after iterating over offers that we have to reject. + // 3. We accept an offer with a `server_max_window_bits` parameter if the manual config allows + // it, and choose the smaller "number of bits" (from clients request). + // 4. We accept an offer with a `server_no_context_takeover` parameter if the manual config + // allows it, and choose the smaller "number of bits" (from manual config) from + // `server_max_window_bits`. + constexpr auto serverOnly = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14"_kj; + + constexpr auto acceptLast = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14, " + "permessage-deflate; " // accept this + "client_no_context_takeover"_kj; + + const auto defaultConfig = CompressionParameters(); + // Our default config is equivalent to `permessage-deflate` with no parameters. + + auto maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, defaultConfig); + KJ_ASSERT(maybeAccepted == nullptr); + // Asserts that we rejected all the offers with `server_x` parameters. + + maybeAccepted = _::tryParseAllExtensionOffers(acceptLast, defaultConfig); + auto accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted the only offer that did not have a `server_x` parameter. + + const auto allowServerBits = CompressionParameters { + false, + false, + 15, // server_max_window_bits = 15 + nullptr + }; + maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, allowServerBits); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == 14); // Note that we chose the lower of (14, 15). + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted an offer that allowed for `server_max_window_bits` AND we chose the + // lower number of bits (in this case, the clients offer of 14). + + const auto allowServerTakeoverAndBits = CompressionParameters { + true, // server_no_context_takeover = true + false, + 13, // server_max_window_bits = 13 + nullptr + }; - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& responseSender) override { - return requestBody.readAllBytes().then([this](kj::Array&&) -> kj::Promise { - KJ_IF_MAYBE(e, exception) { - return kj::cp(*e); - } else { - return kj::READY_NOW; - } - }); - } + maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, allowServerTakeoverAndBits); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == true); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == 13); // Note that we chose the lower of (14, 15). + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted an offer that allowed for `server_no_context_takeover` AND we chose + // the lower number of bits (in this case, the manual config's choice of 13). +} -private: - kj::Maybe exception; -}; +KJ_TEST("WebSocket Compression String Parsing (generateExtensionResponse)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits=10; " + "client_max_window_bits, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; + constexpr auto EXPECTED = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15; " + "server_max_window_bits=10"_kj; + auto accepted = _::tryParseExtensionOffers(extensions); + auto extensionResponse = _::generateExtensionResponse(KJ_ASSERT_NONNULL(accepted)); + KJ_ASSERT(extensionResponse == EXPECTED); +} -KJ_TEST("HttpServer no response") { - auto PIPELINE_TESTS = pipelineTestCases(); +KJ_TEST("WebSocket Compression String Parsing (tryParseExtensionAgreement)") { + constexpr auto didNotOffer = "Server failed WebSocket handshake: " + "added Sec-WebSocket-Extensions when client did not offer any."_kj; + constexpr auto tooMany = "Server failed WebSocket handshake: " + "expected exactly one extension (permessage-deflate) but received more than one."_kj; + constexpr auto badExt = "Server failed WebSocket handshake: " + "response included a Sec-WebSocket-Extensions value that was not permessage-deflate."_kj; + constexpr auto badVal = "Server failed WebSocket handshake: " + "the Sec-WebSocket-Extensions header in the Response included an invalid value."_kj; + + constexpr auto tooManyExtensions = "permessage-deflate; client_no_context_takeover; " + "client_max_window_bits; server_max_window_bits=10, " + "permessage-deflate; client_no_context_takeover; client_max_window_bits;"_kj; + + auto maybeAccepted = _::tryParseExtensionAgreement(nullptr, tooManyExtensions); + KJ_ASSERT( + KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == didNotOffer); + + Maybe defaultConfig = CompressionParameters{}; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, tooManyExtensions); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == tooMany); + + constexpr auto invalidExt = "permessage-invalid; " + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, invalidExt); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badExt); + + constexpr auto invalidVal = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=100;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, invalidVal); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badVal); + + constexpr auto missingVal = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits; " // This must have a value in a Response! + "server_max_window_bits=10;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, missingVal); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badVal); + + constexpr auto valid = "permessage-deflate; client_no_context_takeover; " + "client_max_window_bits=15; server_max_window_bits=10"_kj; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, valid); + auto config = KJ_ASSERT_NONNULL(maybeAccepted.tryGet()); + KJ_ASSERT(config.outboundNoContextTakeover == true); + KJ_ASSERT(config.inboundNoContextTakeover == false); + KJ_ASSERT(config.outboundMaxWindowBits == 15); + KJ_ASSERT(config.inboundMaxWindowBits == 10); + + auto client = CompressionParameters{ true, false, 15, 10 }; + // If the server ignores our `client_no_context_takeover` parameter, we (the client) still use it. + constexpr auto serverIgnores = "permessage-deflate; client_max_window_bits=15; " + "server_max_window_bits=10"_kj; + maybeAccepted = _::tryParseExtensionAgreement(client, serverIgnores); + config = KJ_ASSERT_NONNULL(maybeAccepted.tryGet()); + KJ_ASSERT(config.outboundNoContextTakeover == true); // Note that this is missing in the response. + KJ_ASSERT(config.inboundNoContextTakeover == false); + KJ_ASSERT(config.outboundMaxWindowBits == 15); + KJ_ASSERT(config.inboundMaxWindowBits == 10); +} +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Empty Message Compression") { + // We'll try to send and receive "Hello", then "", followed by "Hello" again. KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_EMPTY_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_EMPTY_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); - KJ_EXPECT(text == - "HTTP/1.1 500 Internal Server Error\r\n" - "Connection: close\r\n" - "Content-Length: 51\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: The HttpService did not generate a response.", text); -} + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; -KJ_TEST("HttpServer disconnected") { - auto PIPELINE_TESTS = pipelineTestCases(); + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketEmptyMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Default Compression") { + // We'll try to send and receive "Hello" twice. The second time we receive "Hello", the compressed + // message will be smaller as a result of the client reusing the lookback window. KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected")); - HttpServer server(timer, table, service); + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); - KJ_EXPECT(text == "", text); -} + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; -KJ_TEST("HttpServer overloaded") { - auto PIPELINE_TESTS = pipelineTestCases(); + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketTwoMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Extract Extensions") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded")); - HttpServer server(timer, table, service); + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); - - KJ_EXPECT(text.startsWith("HTTP/1.1 503 Service Unavailable"), text); -} - -KJ_TEST("HttpServer unimplemented") { - auto PIPELINE_TESTS = pipelineTestCases(); - - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented")); - HttpServer server(timer, table, service); + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketOptimizePumpProxy(waitScope, *headerTable, extHeader, extensions, *client); - KJ_EXPECT(text.startsWith("HTTP/1.1 501 Not Implemented"), text); + serverTask.wait(waitScope); } +#endif // KJ_HAS_ZLIB -KJ_TEST("HttpServer threw exception") { - auto PIPELINE_TESTS = pipelineTestCases(); - +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Compression (Client Discards Compression Context)") { + // We'll try to send and receive "Hello" twice. The second time we receive "Hello", the compressed + // message will be the same size as the first time, since the client discards the lookback window. KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); - HttpServer server(timer, table, service); + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], + asBytes(WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); - KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); -} + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; -KJ_TEST("HttpServer bad request") { + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = + "permessage-deflate; client_no_context_takeover; server_no_context_takeover"_kj; + testWebSocketTwoMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Compression (Different DEFLATE blocks)") { + // In this test, we'll try to use the following DEFLATE blocks: + // - Two DEFLATE blocks in 1 message. + // - A block with no compression. + // - A block with BFINAL set to 1. + // Then, we'll try to send a normal compressed message following the BFINAL message to ensure we + // can still process messages after receiving BFINAL. KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], + asBytes(WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_TWO_DEFLATE_BLOCKS_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_DEFLATE_NO_COMPRESSION_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_BFINAL_SET_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - static constexpr auto request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); - static constexpr auto expectedResponse = - "HTTP/1.1 400 Bad Request\r\n" - "Connection: close\r\n" - "Content-Length: 53\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: The headers sent by your client are not valid."_kj; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; - KJ_EXPECT(expectedResponse == response, expectedResponse, response); + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = + "permessage-deflate; client_no_context_takeover; server_no_context_takeover"_kj; + testWebSocketFourMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); } +#endif // KJ_HAS_ZLIB -KJ_TEST("HttpServer invalid method") { +KJ_TEST("HttpClient WebSocket error") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); - - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - static constexpr auto request = "bad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) + .then([&]() { return expectRead(*pipe.ends[1], request); }) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - static constexpr auto expectedResponse = - "HTTP/1.1 501 Not Implemented\r\n" - "Connection: close\r\n" - "Content-Length: 35\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: Unrecognized request method."_kj; + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); - KJ_EXPECT(expectedResponse == response, expectedResponse, response); -} + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; -// Ensure that HttpServerSettings can continue to be constexpr. -KJ_UNUSED static constexpr HttpServerSettings STATIC_CONSTEXPR_SETTINGS {}; + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); -class TestErrorHandler: public HttpServerErrorHandler { -public: - kj::Promise handleClientProtocolError( - HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) override { - // In a real error handler, you should redact `protocolError.rawContent`. - auto message = kj::str("Saw protocol error: ", protocolError.description, "; rawContent = ", - encodeCEscape(protocolError.rawContent)); - return sendError(400, "Bad Request", kj::mv(message), response); - } + kj::HttpHeaders headers(*headerTable); + headers.set(hMyHeader, "foo"); - kj::Promise handleApplicationError( - kj::Exception exception, kj::Maybe response) override { - return sendError(500, "Internal Server Error", - kj::str("Saw application error: ", exception.getDescription()), response); - } + { + auto response = client->openWebSocket("/websocket", headers).wait(waitScope); - kj::Promise handleNoResponse(kj::HttpService::Response& response) override { - return sendError(500, "Internal Server Error", kj::str("Saw no response."), response); + KJ_EXPECT(response.statusCode == 404); + KJ_EXPECT(response.statusText == "Not Found", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); + KJ_ASSERT(response.webSocketOrBody.is>()); } - static TestErrorHandler instance; + { + auto response = client->openWebSocket("/websocket", headers).wait(waitScope); -private: - kj::Promise sendError(uint statusCode, kj::StringPtr statusText, String message, - Maybe response) { - KJ_IF_MAYBE(r, response) { - HttpHeaderTable headerTable; - HttpHeaders headers(headerTable); - auto body = r->send(statusCode, statusText, headers, message.size()); - return body->write(message.begin(), message.size()).attach(kj::mv(body), kj::mv(message)); - } else { - KJ_LOG(ERROR, "Saw an error but too late to report to client."); - return kj::READY_NOW; - } + KJ_EXPECT(response.statusCode == 404); + KJ_EXPECT(response.statusText == "Not Found", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); + KJ_ASSERT(response.webSocketOrBody.is>()); } -}; - -TestErrorHandler TestErrorHandler::instance {}; -KJ_TEST("HttpServer no response, custom error handler") { - auto PIPELINE_TESTS = pipelineTestCases(); + serverTask.wait(waitScope); +} +KJ_TEST("HttpServer WebSocket handshake") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpServerSettings settings {}; - settings.errorHandler = TestErrorHandler::instance; - - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service, settings); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + HttpServer server(timer, *headerTable, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); - KJ_EXPECT(text == - "HTTP/1.1 500 Internal Server Error\r\n" - "Connection: close\r\n" - "Content-Length: 16\r\n" - "\r\n" - "Saw no response.", text); -} + expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); + writeA(*pipe.ends[1], WEBSOCKET_SEND_CLOSE).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(waitScope); -KJ_TEST("HttpServer threw exception, custom error handler") { - auto PIPELINE_TESTS = pipelineTestCases(); + listenTask.wait(waitScope); +} +KJ_TEST("HttpServer WebSocket handshake error") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpServerSettings settings {}; - settings.errorHandler = TestErrorHandler::instance; - - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); - HttpServer server(timer, table, service, settings); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + HttpServer server(timer, *headerTable, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); - - KJ_EXPECT(text == - "HTTP/1.1 500 Internal Server Error\r\n" - "Connection: close\r\n" - "Content-Length: 29\r\n" - "\r\n" - "Saw application error: failed", text); -} + auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE); + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); -KJ_TEST("HttpServer bad request, custom error handler") { - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + // Can send more requests! + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); - HttpServerSettings settings {}; - settings.errorHandler = TestErrorHandler::instance; + pipe.ends[1]->shutdownWrite(); - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service, settings); + listenTask.wait(waitScope); +} - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); +void testBadWebSocketHandshake( + WaitScope& waitScope, Timer& timer, StringPtr request, StringPtr response, TwoWayPipe pipe) { + // Write an invalid WebSocket GET request, and expect a particular error response. - static constexpr auto request = "bad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); - static constexpr auto expectedResponse = - "HTTP/1.1 400 Bad Request\r\n" - "Connection: close\r\n" - "Content-Length: 80\r\n" - "\r\n" - "Saw protocol error: Unrecognized request method.; " - "rawContent = bad request\\000\\n"_kj; + class ErrorHandler final: public HttpServerErrorHandler { + Promise handleApplicationError( + Exception exception, Maybe response) override { + // When I first wrote this, I expected this function to be called, because + // `TestWebSocketService::request()` definitely throws. However, the exception it throws comes + // from `HttpService::Response::acceptWebSocket()`, which stores the fact which it threw a + // WebSocket error. This prevents the HttpServer's listen loop from propagating the exception + // to our HttpServerErrorHandler (i.e., this function), because it assumes the exception is + // related to the WebSocket error response. See `HttpServer::Connection::startLoop()` for + // details. + bool responseWasSent = response == nullptr; + KJ_FAIL_EXPECT("Unexpected application error", responseWasSent, exception); + return READY_NOW; + } + }; - KJ_EXPECT(expectedResponse == response, expectedResponse, response); + ErrorHandler errorHandler; + + HttpServerSettings serverSettings; + serverSettings.errorHandler = errorHandler; + + HttpServer server(timer, *headerTable, service, serverSettings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], response).wait(waitScope); + + listenTask.wait(waitScope); } -class PartialResponseService final: public HttpService { - // HttpService that sends a partial response then throws. -public: - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& response) override { - return requestBody.readAllBytes() - .then([this,&response](kj::Array&&) -> kj::Promise { - HttpHeaders headers(table); - auto body = response.send(200, "OK", headers, 32); - auto promise = body->write("foo", 3); - return promise.attach(kj::mv(body)).then([]() -> kj::Promise { - return KJ_EXCEPTION(FAILED, "failed"); - }); - }); - } +KJ_TEST("HttpServer WebSocket handshake with unsupported Sec-WebSocket-Version") { + static constexpr auto REQUEST = + "GET /websocket HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 1\r\n" + "My-Header: foo\r\n" + "\r\n"_kj; + + static constexpr auto RESPONSE = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 56\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The requested WebSocket version is not supported."_kj; + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + testBadWebSocketHandshake(waitScope, timer, REQUEST, RESPONSE, KJ_HTTP_TEST_CREATE_2PIPE); +} + +KJ_TEST("HttpServer WebSocket handshake with missing Sec-WebSocket-Key") { + static constexpr auto REQUEST = + "GET /websocket HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "My-Header: foo\r\n" + "\r\n"_kj; + + static constexpr auto RESPONSE = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 32\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Missing Sec-WebSocket-Key"_kj; + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + testBadWebSocketHandshake(waitScope, timer, REQUEST, RESPONSE, KJ_HTTP_TEST_CREATE_2PIPE); +} + +KJ_TEST("HttpServer WebSocket with application error after accept") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + class WebSocketApplicationErrorService: public HttpService, public HttpServerErrorHandler { + // Accepts a WebSocket, receives a message, and throws an exception (application error). + + public: + Promise request( + HttpMethod method, kj::StringPtr, const HttpHeaders&, + AsyncInputStream&, Response& response) override { + KJ_ASSERT(method == HttpMethod::GET); + HttpHeaderTable headerTable; + HttpHeaders responseHeaders(headerTable); + auto webSocket = response.acceptWebSocket(responseHeaders); + return webSocket->receive().then([](WebSocket::Message) { + throwRecoverableException(KJ_EXCEPTION(FAILED, "test exception")); + }).attach(kj::mv(webSocket)); + } + + Promise handleApplicationError(Exception exception, Maybe response) override { + // We accepted the WebSocket, so the response was already sent. At one time, we _did_ expose a + // useless Response reference here, so this is a regression test. + bool responseWasSent = response == nullptr; + KJ_EXPECT(responseWasSent); + KJ_EXPECT(exception.getDescription() == "test exception"_kj); + return READY_NOW; + } + }; + + // Set up the HTTP service. + + WebSocketApplicationErrorService service; + + HttpServerSettings serverSettings; + serverSettings.errorHandler = service; + + HttpHeaderTable headerTable; + HttpServer server(timer, headerTable, service, serverSettings); + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Make a client and open a WebSocket to the service. + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto client = newHttpClient( + headerTable, *pipe.ends[1], clientSettings); + + HttpHeaders headers(headerTable); + auto webSocketResponse = client->openWebSocket("/websocket"_kj, headers) + .wait(waitScope); + + KJ_ASSERT(webSocketResponse.statusCode == 101); + auto webSocket = kj::mv(KJ_ASSERT_NONNULL(webSocketResponse.webSocketOrBody.tryGet>())); + + webSocket->send("ignored"_kj).wait(waitScope); + + listenTask.wait(waitScope); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("HttpServer request timeout") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; -private: - kj::Maybe exception; HttpHeaderTable table; -}; + TestHttpService service(PIPELINE_TESTS, table); + HttpServerSettings settings; + settings.headerTimeout = 1 * kj::MILLISECONDS; + HttpServer server(timer, table, service, settings); -KJ_TEST("HttpServer threw exception after starting response") { + // Shouldn't hang! Should time out. + auto promise = server.listenHttp(kj::mv(pipe.ends[0])); + KJ_EXPECT(!promise.poll(waitScope)); + timer.advanceTo(timer.now() + settings.headerTimeout / 2); + KJ_EXPECT(!promise.poll(waitScope)); + timer.advanceTo(timer.now() + settings.headerTimeout); + promise.wait(waitScope); + + // Closes the connection without sending anything. + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); +} + +KJ_TEST("HttpServer pipeline timeout") { auto PIPELINE_TESTS = pipelineTestCases(); KJ_HTTP_TEST_SETUP_IO; @@ -2778,46 +3826,52 @@ KJ_TEST("HttpServer threw exception after starting response") { auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; - PartialResponseService service; - HttpServer server(timer, table, service); + TestHttpService service(PIPELINE_TESTS, table); + HttpServerSettings settings; + settings.pipelineTimeout = 1 * kj::MILLISECONDS; + HttpServer server(timer, table, service, settings); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - KJ_EXPECT_LOG(ERROR, "HttpService threw exception after generating a partial response"); - // Do one request. pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) .wait(waitScope); - auto text = pipe.ends[1]->readAllText().wait(waitScope); + expectRead(*pipe.ends[1], PIPELINE_TESTS[0].response.raw).wait(waitScope); - KJ_EXPECT(text == - "HTTP/1.1 200 OK\r\n" - "Content-Length: 32\r\n" - "\r\n" - "foo", text); + // Listen task should time out even though we didn't shutdown the socket. + KJ_EXPECT(!listenTask.poll(waitScope)); + timer.advanceTo(timer.now() + settings.pipelineTimeout / 2); + KJ_EXPECT(!listenTask.poll(waitScope)); + timer.advanceTo(timer.now() + settings.pipelineTimeout); + listenTask.wait(waitScope); + + // In this case, no data is sent back. + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); } -class PartialResponseNoThrowService final: public HttpService { - // HttpService that sends a partial response then returns without throwing. +class BrokenHttpService final: public HttpService { + // HttpService that doesn't send a response. public: + BrokenHttpService() = default; + explicit BrokenHttpService(kj::Exception&& exception): exception(kj::mv(exception)) {} + kj::Promise request( HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& response) override { - return requestBody.readAllBytes() - .then([this,&response](kj::Array&&) -> kj::Promise { - HttpHeaders headers(table); - auto body = response.send(200, "OK", headers, 32); - auto promise = body->write("foo", 3); - return promise.attach(kj::mv(body)); + kj::AsyncInputStream& requestBody, Response& responseSender) override { + return requestBody.readAllBytes().then([this](kj::Array&&) -> kj::Promise { + KJ_IF_MAYBE(e, exception) { + return kj::cp(*e); + } else { + return kj::READY_NOW; + } }); } private: kj::Maybe exception; - HttpHeaderTable table; }; -KJ_TEST("HttpServer failed to write complete response but didn't throw") { +KJ_TEST("HttpServer no response") { auto PIPELINE_TESTS = pipelineTestCases(); KJ_HTTP_TEST_SETUP_IO; @@ -2825,7 +3879,7 @@ KJ_TEST("HttpServer failed to write complete response but didn't throw") { auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; - PartialResponseNoThrowService service; + BrokenHttpService service; HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); @@ -2836,58 +3890,36 @@ KJ_TEST("HttpServer failed to write complete response but didn't throw") { auto text = pipe.ends[1]->readAllText().wait(waitScope); KJ_EXPECT(text == - "HTTP/1.1 200 OK\r\n" - "Content-Length: 32\r\n" + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 51\r\n" + "Content-Type: text/plain\r\n" "\r\n" - "foo", text); + "ERROR: The HttpService did not generate a response.", text); } -class SimpleInputStream final: public kj::AsyncInputStream { - // An InputStream that returns bytes out of a static string. - -public: - SimpleInputStream(kj::StringPtr text) - : unread(text.asBytes()) {} +KJ_TEST("HttpServer disconnected") { + auto PIPELINE_TESTS = pipelineTestCases(); - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - size_t amount = kj::min(maxBytes, unread.size()); - memcpy(buffer, unread.begin(), amount); - unread = unread.slice(amount, unread.size()); - return amount; - } + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; -private: - kj::ArrayPtr unread; -}; + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected")); + HttpServer server(timer, table, service); -class PumpResponseService final: public HttpService { - // HttpService that uses pumpTo() to write a response, without carefully specifying how much to - // pump, but the stream happens to be the right size. -public: - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& response) override { - return requestBody.readAllBytes() - .then([this,&response](kj::Array&&) -> kj::Promise { - HttpHeaders headers(table); - kj::StringPtr text = "Hello, World!"; - auto body = response.send(200, "OK", headers, text.size()); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - auto stream = kj::heap(text); - auto promise = stream->pumpTo(*body); - return promise.attach(kj::mv(body), kj::mv(stream)) - .then([text](uint64_t amount) { - KJ_EXPECT(amount == text.size()); - }); - }); - } + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); -private: - kj::Maybe exception; - HttpHeaderTable table; -}; + KJ_EXPECT(text == "", text); +} -KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { +KJ_TEST("HttpServer overloaded") { auto PIPELINE_TESTS = pipelineTestCases(); KJ_HTTP_TEST_SETUP_IO; @@ -2895,7 +3927,7 @@ KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; - PumpResponseService service; + BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded")); HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); @@ -2903,1228 +3935,3244 @@ KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { // Do one request. pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) .wait(waitScope); - pipe.ends[1]->shutdownWrite(); auto text = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(text == - "HTTP/1.1 200 OK\r\n" - "Content-Length: 13\r\n" - "\r\n" - "Hello, World!", text); + KJ_EXPECT(text.startsWith("HTTP/1.1 503 Service Unavailable"), text); } -class HangingHttpService final: public HttpService { - // HttpService that hangs forever. -public: - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& responseSender) override { - kj::Promise result = kj::NEVER_DONE; - ++inFlight; - return result.attach(kj::defer([this]() { - if (--inFlight == 0) { - KJ_IF_MAYBE(f, onCancelFulfiller) { - f->get()->fulfill(); - } - } - })); - } - - kj::Promise onCancel() { - auto paf = kj::newPromiseAndFulfiller(); - onCancelFulfiller = kj::mv(paf.fulfiller); - return kj::mv(paf.promise); - } - - uint inFlight = 0; - -private: - kj::Maybe exception; - kj::Maybe>> onCancelFulfiller; -}; +KJ_TEST("HttpServer unimplemented") { + auto PIPELINE_TESTS = pipelineTestCases(); -KJ_TEST("HttpServer cancels request when client disconnects") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; - HangingHttpService service; + BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented")); HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - KJ_EXPECT(service.inFlight == 0); - - static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj; - pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); - - auto cancelPromise = service.onCancel(); - KJ_EXPECT(!cancelPromise.poll(waitScope)); - KJ_EXPECT(service.inFlight == 1); + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); - // Disconnect client and verify server cancels. - pipe.ends[1] = nullptr; - KJ_ASSERT(cancelPromise.poll(waitScope)); - KJ_EXPECT(service.inFlight == 0); - cancelPromise.wait(waitScope); + KJ_EXPECT(text.startsWith("HTTP/1.1 501 Not Implemented"), text); } -// ----------------------------------------------------------------------------- - -KJ_TEST("newHttpService from HttpClient") { +KJ_TEST("HttpServer threw exception") { auto PIPELINE_TESTS = pipelineTestCases(); KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); - auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; - - kj::Promise writeResponsesPromise = kj::READY_NOW; - for (auto& testCase: PIPELINE_TESTS) { - writeResponsesPromise = writeResponsesPromise - .then([&]() { - return expectRead(*backPipe.ends[1], testCase.request.raw); - }).then([&]() { - return backPipe.ends[1]->write(testCase.response.raw.begin(), testCase.response.raw.size()); - }); - } - - { - HttpHeaderTable table; - auto backClient = newHttpClient(table, *backPipe.ends[0]); - auto frontService = newHttpService(*backClient); - HttpServer frontServer(timer, table, *frontService); - auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); - - for (auto& testCase: PIPELINE_TESTS) { - KJ_CONTEXT(testCase.request.raw, testCase.response.raw); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - frontPipe.ends[0]->write(testCase.request.raw.begin(), testCase.request.raw.size()) - .wait(waitScope); + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); + HttpServer server(timer, table, service); - expectRead(*frontPipe.ends[0], testCase.response.raw).wait(waitScope); - } + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - frontPipe.ends[0]->shutdownWrite(); - listenTask.wait(waitScope); - } + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); - backPipe.ends[0]->shutdownWrite(); - writeResponsesPromise.wait(waitScope); + KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); } -KJ_TEST("newHttpService from HttpClient WebSockets") { +KJ_TEST("HttpServer bad request") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); - auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; - - auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) - .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) - .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) - .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) - .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) - .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_CLOSE); }) - .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) - .then([&]() { return expectEnd(*backPipe.ends[1]); }) - .then([&]() { backPipe.ends[1]->shutdownWrite(); }) - .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - { - HttpHeaderTable table; - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; - auto backClientStream = kj::mv(backPipe.ends[0]); - auto backClient = newHttpClient(table, *backClientStream, clientSettings); - auto frontService = newHttpService(*backClient); - HttpServer frontServer(timer, table, *frontService); - auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service); - writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); - expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); - writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); - expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); - writeA(*frontPipe.ends[0], WEBSOCKET_SEND_CLOSE).wait(waitScope); - expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_CLOSE).wait(waitScope); + static constexpr auto request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); - frontPipe.ends[0]->shutdownWrite(); - listenTask.wait(waitScope); - } + static constexpr auto expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 53\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The headers sent by your client are not valid."_kj; - writeResponsesPromise.wait(waitScope); + KJ_EXPECT(expectedResponse == response, expectedResponse, response); } -KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { +KJ_TEST("HttpServer invalid method") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); - auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); - auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) - .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) - .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) - .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) - .then([&]() { backPipe.ends[1]->shutdownWrite(); }) - .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service); - { - HttpHeaderTable table; - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; - auto backClient = newHttpClient(table, *backPipe.ends[0], clientSettings); - auto frontService = newHttpService(*backClient); - HttpServer frontServer(timer, table, *frontService); - auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); - expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + static constexpr auto request = "bad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); - expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); - writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + static constexpr auto expectedResponse = + "HTTP/1.1 501 Not Implemented\r\n" + "Connection: close\r\n" + "Content-Length: 35\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Unrecognized request method."_kj; - KJ_EXPECT(frontPipe.ends[0]->readAllText().wait(waitScope) == ""); + KJ_EXPECT(expectedResponse == response, expectedResponse, response); +} - frontPipe.ends[0]->shutdownWrite(); - listenTask.wait(waitScope); +// Ensure that HttpServerSettings can continue to be constexpr. +KJ_UNUSED static constexpr HttpServerSettings STATIC_CONSTEXPR_SETTINGS {}; + +class TestErrorHandler: public HttpServerErrorHandler { +public: + kj::Promise handleClientProtocolError( + HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) override { + // In a real error handler, you should redact `protocolError.rawContent`. + auto message = kj::str("Saw protocol error: ", protocolError.description, "; rawContent = ", + encodeCEscape(protocolError.rawContent)); + return sendError(400, "Bad Request", kj::mv(message), response); } - writeResponsesPromise.wait(waitScope); -} + kj::Promise handleApplicationError( + kj::Exception exception, kj::Maybe response) override { + return sendError(500, "Internal Server Error", + kj::str("Saw application error: ", exception.getDescription()), response); + } -// ----------------------------------------------------------------------------- + kj::Promise handleNoResponse(kj::HttpService::Response& response) override { + return sendError(500, "Internal Server Error", kj::str("Saw no response."), response); + } -KJ_TEST("newHttpClient from HttpService") { - auto PIPELINE_TESTS = pipelineTestCases(); + static TestErrorHandler instance; - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); +private: + kj::Promise sendError(uint statusCode, kj::StringPtr statusText, String message, + Maybe response) { + KJ_IF_MAYBE(r, response) { + HttpHeaderTable headerTable; + HttpHeaders headers(headerTable); + auto body = r->send(statusCode, statusText, headers, message.size()); + return body->write(message.begin(), message.size()).attach(kj::mv(body), kj::mv(message)); + } else { + KJ_LOG(ERROR, "Saw an error but too late to report to client."); + return kj::READY_NOW; + } + } +}; - HttpHeaderTable table; - TestHttpService service(PIPELINE_TESTS, table); - auto client = newHttpClient(service); +TestErrorHandler TestErrorHandler::instance {}; - for (auto& testCase: PIPELINE_TESTS) { - testHttpClient(waitScope, table, *client, testCase); - } -} +KJ_TEST("HttpServer no response, custom error handler") { + auto PIPELINE_TESTS = pipelineTestCases(); -KJ_TEST("newHttpClient from HttpService WebSockets") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable::Builder tableBuilder; - HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); - auto headerTable = tableBuilder.build(); - TestWebSocketService service(*headerTable, hMyHeader); - auto client = newHttpClient(service); + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; - testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 16\r\n" + "\r\n" + "Saw no response.", text); } -KJ_TEST("adapted client/server propagates request exceptions like non-adapted client") { +KJ_TEST("HttpServer threw exception, custom error handler") { + auto PIPELINE_TESTS = pipelineTestCases(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; HttpHeaderTable table; - HttpHeaders headers(table); + BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); + HttpServer server(timer, table, service, settings); - class FailingHttpClient final: public HttpClient { - public: - Request request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::Maybe expectedBodySize = nullptr) override { - KJ_FAIL_ASSERT("request_fail"); - } + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - kj::Promise openWebSocket( - kj::StringPtr url, const HttpHeaders& headers) override { - KJ_FAIL_ASSERT("websocket_fail"); - } - }; + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); - auto rawClient = kj::heap(); + KJ_EXPECT(text == + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 29\r\n" + "\r\n" + "Saw application error: failed", text); +} - auto innerClient = kj::heap(); - auto adaptedService = kj::newHttpService(*innerClient).attach(kj::mv(innerClient)); - auto adaptedClient = kj::newHttpClient(*adaptedService).attach(kj::mv(adaptedService)); +KJ_TEST("HttpServer bad request, custom error handler") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - KJ_EXPECT_THROW_MESSAGE("request_fail", rawClient->request(HttpMethod::POST, "/"_kj, headers)); - KJ_EXPECT_THROW_MESSAGE("request_fail", adaptedClient->request(HttpMethod::POST, "/"_kj, headers)); + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; - KJ_EXPECT_THROW_MESSAGE("websocket_fail", rawClient->openWebSocket("/"_kj, headers)); - KJ_EXPECT_THROW_MESSAGE("websocket_fail", adaptedClient->openWebSocket("/"_kj, headers)); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + static constexpr auto request = "bad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + static constexpr auto expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 80\r\n" + "\r\n" + "Saw protocol error: Unrecognized request method.; " + "rawContent = bad request\\000\\n"_kj; + + KJ_EXPECT(expectedResponse == response, expectedResponse, response); } -class DelayedCompletionHttpService final: public HttpService { +class PartialResponseService final: public HttpService { + // HttpService that sends a partial response then throws. public: - DelayedCompletionHttpService(HttpHeaderTable& table, kj::Maybe expectedLength) - : table(table), expectedLength(expectedLength) {} - kj::Promise request( HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, kj::AsyncInputStream& requestBody, Response& response) override { - auto stream = response.send(200, "OK", HttpHeaders(table), expectedLength); - auto promise = stream->write("foo", 3); - return promise.attach(kj::mv(stream)).then([this]() { - return kj::mv(paf.promise); + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + auto body = response.send(200, "OK", headers, 32); + auto promise = body->write("foo", 3); + return promise.attach(kj::mv(body)).then([]() -> kj::Promise { + return KJ_EXCEPTION(FAILED, "failed"); + }); }); } - kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } - private: - HttpHeaderTable& table; - kj::Maybe expectedLength; - kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); + kj::Maybe exception; + HttpHeaderTable table; }; -void doDelayedCompletionTest(bool exception, kj::Maybe expectedLength) noexcept { +KJ_TEST("HttpServer threw exception after starting response") { + auto PIPELINE_TESTS = pipelineTestCases(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; + PartialResponseService service; + HttpServer server(timer, table, service); - DelayedCompletionHttpService service(table, expectedLength); - auto client = newHttpClient(service); - - auto resp = client->request(HttpMethod::GET, "/", HttpHeaders(table), uint64_t(0)) - .response.wait(waitScope); - KJ_EXPECT(resp.statusCode == 200); - - // Read "foo" from the response body: works - char buffer[16]; - KJ_ASSERT(resp.body->tryRead(buffer, 1, sizeof(buffer)).wait(waitScope) == 3); - buffer[3] = '\0'; - KJ_EXPECT(buffer == "foo"_kj); - - // But reading any more hangs. - auto promise = resp.body->tryRead(buffer, 1, sizeof(buffer)); - - KJ_EXPECT(!promise.poll(waitScope)); - - // Until we cause the service to return. - if (exception) { - service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); - } else { - service.getFulfiller().fulfill(); - } - - KJ_ASSERT(promise.poll(waitScope)); - - if (exception) { - KJ_EXPECT_THROW_MESSAGE("service-side failure", promise.wait(waitScope)); - } else { - promise.wait(waitScope); - } -}; - -KJ_TEST("adapted client waits for service to complete before returning EOF on response stream") { - doDelayedCompletionTest(false, uint64_t(3)); -} - -KJ_TEST("adapted client waits for service to complete before returning EOF on chunked response") { - doDelayedCompletionTest(false, nullptr); -} + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); -KJ_TEST("adapted client propagates throw from service after complete response body sent") { - doDelayedCompletionTest(true, uint64_t(3)); -} + KJ_EXPECT_LOG(ERROR, "HttpService threw exception after generating a partial response"); -KJ_TEST("adapted client propagates throw from service after incomplete response body sent") { - doDelayedCompletionTest(true, uint64_t(6)); -} + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); -KJ_TEST("adapted client propagates throw from service after chunked response body sent") { - doDelayedCompletionTest(true, nullptr); + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 32\r\n" + "\r\n" + "foo", text); } -class DelayedCompletionWebSocketHttpService final: public HttpService { +class PartialResponseNoThrowService final: public HttpService { + // HttpService that sends a partial response then returns without throwing. public: - DelayedCompletionWebSocketHttpService(HttpHeaderTable& table, bool closeUpstreamFirst) - : table(table), closeUpstreamFirst(closeUpstreamFirst) {} - kj::Promise request( HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, kj::AsyncInputStream& requestBody, Response& response) override { - KJ_ASSERT(headers.isWebSocket()); - - auto ws = response.acceptWebSocket(HttpHeaders(table)); - kj::Promise promise = kj::READY_NOW; - if (closeUpstreamFirst) { - // Wait for a close message from the client before starting. - promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); - } - promise = promise - .then([&ws = *ws]() { return ws.send("foo"_kj); }) - .then([&ws = *ws]() { return ws.close(1234, "closed"_kj); }); - if (!closeUpstreamFirst) { - // Wait for a close message from the client at the end. - promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); - } - return promise.attach(kj::mv(ws)).then([this]() { - return kj::mv(paf.promise); + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + auto body = response.send(200, "OK", headers, 32); + auto promise = body->write("foo", 3); + return promise.attach(kj::mv(body)); }); } - kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } - private: - HttpHeaderTable& table; - bool closeUpstreamFirst; - kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); + kj::Maybe exception; + HttpHeaderTable table; }; -void doDelayedCompletionWebSocketTest(bool exception, bool closeUpstreamFirst) noexcept { +KJ_TEST("HttpServer failed to write complete response but didn't throw") { + auto PIPELINE_TESTS = pipelineTestCases(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; + PartialResponseNoThrowService service; + HttpServer server(timer, table, service); - DelayedCompletionWebSocketHttpService service(table, closeUpstreamFirst); - auto client = newHttpClient(service); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - auto resp = client->openWebSocket("/", HttpHeaders(table)).wait(waitScope); - auto ws = kj::mv(KJ_ASSERT_NONNULL(resp.webSocketOrBody.tryGet>())); + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); - if (closeUpstreamFirst) { - // Send "close" immediately. - ws->close(1234, "whatever"_kj).wait(waitScope); - } + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 32\r\n" + "\r\n" + "foo", text); +} - // Read "foo" from the WebSocket: works - { - auto msg = ws->receive().wait(waitScope); - KJ_ASSERT(msg.is()); - KJ_ASSERT(msg.get() == "foo"); - } +class SimpleInputStream final: public kj::AsyncInputStream { + // An InputStream that returns bytes out of a static string. - kj::Promise promise = nullptr; - if (closeUpstreamFirst) { - // Receiving the close hangs. - promise = ws->receive() - .then([](WebSocket::Message&& msg) { KJ_EXPECT(msg.is()); }); - } else { - auto msg = ws->receive().wait(waitScope); - KJ_ASSERT(msg.is()); +public: + SimpleInputStream(kj::StringPtr text) + : unread(text.asBytes()) {} - // Sending a close hangs. - promise = ws->close(1234, "whatever"_kj); + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + size_t amount = kj::min(maxBytes, unread.size()); + memcpy(buffer, unread.begin(), amount); + unread = unread.slice(amount, unread.size()); + return amount; } - KJ_EXPECT(!promise.poll(waitScope)); - // Until we cause the service to return. - if (exception) { - service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); - } else { - service.getFulfiller().fulfill(); - } +private: + kj::ArrayPtr unread; +}; - KJ_ASSERT(promise.poll(waitScope)); +class PumpResponseService final: public HttpService { + // HttpService that uses pumpTo() to write a response, without carefully specifying how much to + // pump, but the stream happens to be the right size. +public: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + kj::StringPtr text = "Hello, World!"; + auto body = response.send(200, "OK", headers, text.size()); - if (exception) { - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("service-side failure", promise.wait(waitScope)); - } else { - promise.wait(waitScope); + auto stream = kj::heap(text); + auto promise = stream->pumpTo(*body); + return promise.attach(kj::mv(body), kj::mv(stream)) + .then([text](uint64_t amount) { + KJ_EXPECT(amount == text.size()); + }); + }); } + +private: + kj::Maybe exception; + HttpHeaderTable table; }; -KJ_TEST("adapted client waits for service to complete before completing upstream close on WebSocket") { - doDelayedCompletionWebSocketTest(false, false); -} +KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { + auto PIPELINE_TESTS = pipelineTestCases(); -KJ_TEST("adapted client waits for service to complete before returning downstream close on WebSocket") { - doDelayedCompletionWebSocketTest(false, true); -} + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; -KJ_TEST("adapted client propagates throw from service after WebSocket upstream close sent") { - doDelayedCompletionWebSocketTest(true, false); -} + HttpHeaderTable table; + PumpResponseService service; + HttpServer server(timer, table, service); -KJ_TEST("adapted client propagates throw from service after WebSocket downstream close sent") { - doDelayedCompletionWebSocketTest(true, true); -} + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); -// ----------------------------------------------------------------------------- + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + pipe.ends[1]->shutdownWrite(); + auto text = pipe.ends[1]->readAllText().wait(waitScope); -class CountingIoStream final: public kj::AsyncIoStream { - // An AsyncIoStream wrapper which decrements a counter when destroyed (allowing us to count how - // many connections are open). + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 13\r\n" + "\r\n" + "Hello, World!", text); +} +class HangingHttpService final: public HttpService { + // HttpService that hangs forever. public: - CountingIoStream(kj::Own inner, uint& count) - : inner(kj::mv(inner)), count(count) {} - ~CountingIoStream() noexcept(false) { - --count; + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + kj::Promise result = kj::NEVER_DONE; + ++inFlight; + return result.attach(kj::defer([this]() { + if (--inFlight == 0) { + KJ_IF_MAYBE(f, onCancelFulfiller) { + f->get()->fulfill(); + } + } + })); } - kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { - return inner->read(buffer, minBytes, maxBytes); - } - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return inner->tryRead(buffer, minBytes, maxBytes); - } - kj::Maybe tryGetLength() override { - return inner->tryGetLength();; - } - kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { - return inner->pumpTo(output, amount); - } - kj::Promise write(const void* buffer, size_t size) override { - return inner->write(buffer, size); - } - kj::Promise write(kj::ArrayPtr> pieces) override { - return inner->write(pieces); - } - kj::Maybe> tryPumpFrom( - kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { - return inner->tryPumpFrom(input, amount); - } - Promise whenWriteDisconnected() override { - return inner->whenWriteDisconnected(); - } - void shutdownWrite() override { - return inner->shutdownWrite(); - } - void abortRead() override { - return inner->abortRead(); + kj::Promise onCancel() { + auto paf = kj::newPromiseAndFulfiller(); + onCancelFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); } -public: - kj::Own inner; - uint& count; + uint inFlight = 0; + +private: + kj::Maybe exception; + kj::Maybe>> onCancelFulfiller; }; -class CountingNetworkAddress final: public kj::NetworkAddress { -public: - CountingNetworkAddress(kj::NetworkAddress& inner, uint& count, uint& cumulative) - : inner(inner), count(count), addrCount(ownAddrCount), cumulative(cumulative) {} - CountingNetworkAddress(kj::Own inner, uint& count, uint& addrCount) - : inner(*inner), ownInner(kj::mv(inner)), count(count), addrCount(addrCount), - cumulative(ownCumulative) {} - ~CountingNetworkAddress() noexcept(false) { - --addrCount; - } +KJ_TEST("HttpServer cancels request when client disconnects") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - kj::Promise> connect() override { - ++count; - ++cumulative; - return inner.connect() - .then([this](kj::Own stream) -> kj::Own { - return kj::heap(kj::mv(stream), count); - }); - } + HttpHeaderTable table; + HangingHttpService service; + HttpServer server(timer, table, service); - kj::Own listen() override { KJ_UNIMPLEMENTED("test"); } - kj::Own clone() override { KJ_UNIMPLEMENTED("test"); } - kj::String toString() override { KJ_UNIMPLEMENTED("test"); } + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); -private: - kj::NetworkAddress& inner; - kj::Own ownInner; - uint& count; - uint ownAddrCount = 1; - uint& addrCount; - uint ownCumulative = 0; - uint& cumulative; -}; + KJ_EXPECT(service.inFlight == 0); + + static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj; + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + + auto cancelPromise = service.onCancel(); + KJ_EXPECT(!cancelPromise.poll(waitScope)); + KJ_EXPECT(service.inFlight == 1); + + // Disconnect client and verify server cancels. + pipe.ends[1] = nullptr; + KJ_ASSERT(cancelPromise.poll(waitScope)); + KJ_EXPECT(service.inFlight == 0); + cancelPromise.wait(waitScope); +} + +class SuspendAfter: private HttpService { + // A SuspendableHttpServiceFactory which responds to the first `n` requests with 200 OK, then + // suspends all subsequent requests until its counter is reset. -class ConnectionCountingNetwork final: public kj::Network { public: - ConnectionCountingNetwork(kj::Network& inner, uint& count, uint& addrCount) - : inner(inner), count(count), addrCount(addrCount) {} + void suspendAfter(uint countdownParam) { countdown = countdownParam; } - Promise> parseAddress(StringPtr addr, uint portHint = 0) override { - ++addrCount; - return inner.parseAddress(addr, portHint) - .then([this](Own&& addr) -> Own { - return kj::heap(kj::mv(addr), count, addrCount); - }); - } - Own getSockaddr(const void* sockaddr, uint len) override { - KJ_UNIMPLEMENTED("test"); + kj::Maybe> operator()(HttpServer::SuspendableRequest& sr) { + if (countdown == 0) { + suspendedRequest = sr.suspend(); + return nullptr; + } + --countdown; + return kj::Own(static_cast(this), kj::NullDisposer::instance); } - Own restrictPeers( - kj::ArrayPtr allow, - kj::ArrayPtr deny = nullptr) override { - KJ_UNIMPLEMENTED("test"); + + kj::Maybe getSuspended() { + KJ_DEFER(suspendedRequest = nullptr); + return kj::mv(suspendedRequest); } private: - kj::Network& inner; - uint& count; - uint& addrCount; -}; - -class DummyService final: public HttpService { -public: - DummyService(HttpHeaderTable& headerTable): headerTable(headerTable) {} - kj::Promise request( HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, kj::AsyncInputStream& requestBody, Response& response) override { - if (!headers.isWebSocket()) { - if (url == "/throw") { - return KJ_EXCEPTION(FAILED, "client requested failure"); - } - - auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); - auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size()); - auto promises = kj::heapArrayBuilder>(2); - promises.add(stream->write(body.begin(), body.size())); - promises.add(requestBody.readAllBytes().ignoreResult()); - return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); - } else { - auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); - auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); - auto sendPromise = ws->send(body); - - auto promises = kj::heapArrayBuilder>(2); - promises.add(sendPromise.attach(kj::mv(body))); - promises.add(ws->receive().ignoreResult()); - return kj::joinPromises(promises.finish()).attach(kj::mv(ws)); - } + HttpHeaders responseHeaders(table); + response.send(200, "OK", responseHeaders); + return requestBody.readAllBytes().ignoreResult(); } -private: - HttpHeaderTable& headerTable; + HttpHeaderTable table; + + uint countdown = kj::maxValue; + kj::Maybe suspendedRequest; }; -KJ_TEST("HttpClient connection management") { +KJ_TEST("HttpServer can suspend a request") { + // This test sends a single request to an HttpServer three times. First it writes the request to + // its pipe and arranges for the HttpServer to suspend the request. Then it resumes the suspended + // request and arranges for this resumption to be suspended as well. Then it resumes once more and + // arranges for the request to be completed. + KJ_HTTP_TEST_SETUP_IO; - KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - kj::TimerImpl serverTimer(kj::origin()); - kj::TimerImpl clientTimer(kj::origin()); - HttpHeaderTable headerTable; + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); - DummyService service(headerTable); - HttpServerSettings serverSettings; - HttpServer server(serverTimer, headerTable, service, serverSettings); - auto listenTask = server.listenHttp(*listener); + kj::Maybe suspendedRequest; - uint count = 0; - uint cumulative = 0; - CountingNetworkAddress countingAddr(*addr, count, cumulative); + SuspendAfter factory; - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; - auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + { + // Observe the HttpServer suspend. - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 0); + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); - uint i = 0; - auto doRequest = [&]() { - uint n = i++; - return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response - .then([](HttpClient::Response&& response) { - auto promise = response.body->readAllText(); - return promise.attach(kj::mv(response.body)); - }).then([n](kj::String body) { - KJ_EXPECT(body == kj::str("null:/", n)); + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + // The listen promise is fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // And we have a SuspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // Observe the HttpServer suspend again without reading from the connection. + + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + // The listen promise is again fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // We again have a suspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // The SuspendedRequest is completed. + + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); }); - }; - // We can do several requests in a row and only have one connection. - doRequest().wait(waitScope); - doRequest().wait(waitScope); - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 1); + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); - // But if we do two in parallel, we'll end up with two connections. - auto req1 = doRequest(); - auto req2 = doRequest(); - req1.wait(waitScope); - req2.wait(waitScope); - KJ_EXPECT(count == 2); - KJ_EXPECT(cumulative == 2); + // This time, the server drained cleanly. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto response = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); + } +} + +KJ_TEST("HttpServer can suspend and resume pipelined requests") { + // This test sends multiple requests with both Content-Length and Transfer-Encoding: chunked + // bodies, and verifies that suspending both kinds does not corrupt the stream. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + // We'll suspend the second request. + kj::Maybe suspendedRequest; + SuspendAfter factory; + + static constexpr kj::StringPtr LENGTHFUL_REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "foobar"_kj; + static constexpr kj::StringPtr CHUNKED_REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + + // Set up several requests; we'll suspend and transfer the second and third one. + auto writePromise = pipe.ends[1]->write(LENGTHFUL_REQUEST.begin(), LENGTHFUL_REQUEST.size()) + .then([&]() { + return pipe.ends[1]->write(CHUNKED_REQUEST.begin(), CHUNKED_REQUEST.size()); + }).then([&]() { + return pipe.ends[1]->write(LENGTHFUL_REQUEST.begin(), LENGTHFUL_REQUEST.size()); + }).then([&]() { + return pipe.ends[1]->write(CHUNKED_REQUEST.begin(), CHUNKED_REQUEST.size()); + }); + + auto readPromise = pipe.ends[1]->readAllText(); - // We can reuse after a POST, provided we write the whole POST body properly. { - auto req = client->request( - HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)); - req.body->write("foobar", 6).wait(waitScope); - req.response.wait(waitScope).body->readAllBytes().wait(waitScope); + // Observe the HttpServer suspend the second request. + + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); } - KJ_EXPECT(count == 2); - KJ_EXPECT(cumulative == 2); - doRequest().wait(waitScope); - KJ_EXPECT(count == 2); - KJ_EXPECT(cumulative == 2); - // Advance time for half the timeout, then exercise one of the connections. - clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); - doRequest().wait(waitScope); - doRequest().wait(waitScope); - waitScope.poll(); - KJ_EXPECT(count == 2); - KJ_EXPECT(cumulative == 2); + { + // Let's resume one request and suspend the next pipelined request. - // Advance time past when the other connection should time out. It should be dropped. - clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 3 / 4); - waitScope.poll(); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 2); + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); - // Wait for the other to drop. - clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); - waitScope.poll(); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 2); + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } - // New request creates a new connection again. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 3); + { + // Resume again and run to completion. - // WebSocket connections are not reused. - client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable)) - .wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 3); + factory.suspendAfter(kj::maxValue); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); - // Errored connections are not reused. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 4); - client->request(HttpMethod::GET, kj::str("/throw"), HttpHeaders(headerTable)).response - .wait(waitScope).body->readAllBytes().wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 4); + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); - // Connections where we failed to read the full response body are not reused. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 5); - client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)).response - .wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 5); + // This time, the server drained cleanly. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + // No suspended request this time. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest == nullptr); - // Connections where we didn't even wait for the response headers are not reused. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 6); - client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 6); + drainPromise.wait(waitScope); + } + + writePromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto responses = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(kj::str(kj::delimited(kj::repeat(RESPONSE, 4), "")) == responses); +} + +KJ_TEST("HttpServer can suspend a request with no leftover") { + // This test verifies that if the request loop's read perfectly ends at the end of message + // headers, leaving no leftover section, we can still successfully suspend and resume. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + kj::Maybe suspendedRequest; + + SuspendAfter factory; + + { + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + static constexpr kj::StringPtr REQUEST_HEADERS = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST_HEADERS.begin(), REQUEST_HEADERS.size()).wait(waitScope); + + // The listen promise is fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // And we have a SuspendedRequest. We know that it has no leftover, because we only wrote + // headers, no body yet. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); + + static constexpr kj::StringPtr REQUEST_BODY = + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST_BODY.begin(), REQUEST_BODY.size()).wait(waitScope); + + // Clean drain. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // No SuspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest == nullptr); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto response = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); + } +} + +KJ_TEST("HttpServer::listenHttpCleanDrain() factory-created services outlive requests") { + // Test that the lifetimes of factory-created Own objects are handled correctly. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + uint serviceCount = 0; + + // A factory which returns a service whose request() function responds asynchronously. + auto factory = [&](HttpServer::SuspendableRequest&) -> kj::Own { + class ServiceImpl final: public HttpService { + public: + explicit ServiceImpl(uint& serviceCount): serviceCount(++serviceCount) {} + ~ServiceImpl() noexcept(false) { --serviceCount; } + KJ_DISALLOW_COPY_AND_MOVE(ServiceImpl); + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return evalLater([&serviceCount = serviceCount, &table = table, &requestBody, &response]() { + // This KJ_EXPECT here is the entire point of this test. + KJ_EXPECT(serviceCount == 1) + HttpHeaders responseHeaders(table); + response.send(200, "OK", responseHeaders); + return requestBody.readAllBytes().ignoreResult(); + }); + } + + private: + HttpHeaderTable table; + + uint& serviceCount; + }; + + return kj::heap(serviceCount); + }; + + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "foobar"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); + + // http-socketpair-test quirk: we must drive the request loop past the point of receiving request + // headers so that our call to server.drain() doesn't prematurely cancel the request. + KJ_EXPECT(!listenPromise.poll(waitScope)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // Clean drain. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + auto response = readPromise.wait(waitScope); + + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("newHttpService from HttpClient") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::Promise writeResponsesPromise = kj::READY_NOW; + for (auto& testCase: PIPELINE_TESTS) { + writeResponsesPromise = writeResponsesPromise + .then([&]() { + return expectRead(*backPipe.ends[1], testCase.request.raw); + }).then([&]() { + return backPipe.ends[1]->write(testCase.response.raw.begin(), testCase.response.raw.size()); + }); + } + + { + HttpHeaderTable table; + auto backClient = newHttpClient(table, *backPipe.ends[0]); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + for (auto& testCase: PIPELINE_TESTS) { + KJ_CONTEXT(testCase.request.raw, testCase.response.raw); + + frontPipe.ends[0]->write(testCase.request.raw.begin(), testCase.request.raw.size()) + .wait(waitScope); + + expectRead(*frontPipe.ends[0], testCase.response.raw).wait(waitScope); + } + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + backPipe.ends[0]->shutdownWrite(); + writeResponsesPromise.wait(waitScope); +} + +KJ_TEST("newHttpService from HttpClient WebSockets") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) + .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .then([&]() { return expectEnd(*backPipe.ends[1]); }) + .then([&]() { backPipe.ends[1]->shutdownWrite(); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto backClientStream = kj::mv(backPipe.ends[0]); + auto backClient = newHttpClient(table, *backClientStream, clientSettings); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + + expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_CLOSE).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_CLOSE).wait(waitScope); + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + writeResponsesPromise.wait(waitScope); +} + +KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) + .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { backPipe.ends[1]->shutdownWrite(); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto backClient = newHttpClient(table, *backPipe.ends[0], clientSettings); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + + expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + + KJ_EXPECT(frontPipe.ends[0]->readAllText().wait(waitScope) == ""); + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + writeResponsesPromise.wait(waitScope); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("newHttpClient from HttpService") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + TestHttpService service(PIPELINE_TESTS, table); + auto client = newHttpClient(service); + + for (auto& testCase: PIPELINE_TESTS) { + testHttpClient(waitScope, table, *client, testCase); + } +} + +KJ_TEST("newHttpClient from HttpService WebSockets") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + auto client = newHttpClient(service); + + testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); +} + +KJ_TEST("adapted client/server propagates request exceptions like non-adapted client") { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + HttpHeaders headers(table); + + class FailingHttpClient final: public HttpClient { + public: + Request request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_FAIL_ASSERT("request_fail"); + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + KJ_FAIL_ASSERT("websocket_fail"); + } + }; + + auto rawClient = kj::heap(); + + auto innerClient = kj::heap(); + auto adaptedService = kj::newHttpService(*innerClient).attach(kj::mv(innerClient)); + auto adaptedClient = kj::newHttpClient(*adaptedService).attach(kj::mv(adaptedService)); + + KJ_EXPECT_THROW_MESSAGE("request_fail", rawClient->request(HttpMethod::POST, "/"_kj, headers)); + KJ_EXPECT_THROW_MESSAGE("request_fail", adaptedClient->request(HttpMethod::POST, "/"_kj, headers)); + + KJ_EXPECT_THROW_MESSAGE("websocket_fail", rawClient->openWebSocket("/"_kj, headers)); + KJ_EXPECT_THROW_MESSAGE("websocket_fail", adaptedClient->openWebSocket("/"_kj, headers)); +} + +class DelayedCompletionHttpService final: public HttpService { +public: + DelayedCompletionHttpService(HttpHeaderTable& table, kj::Maybe expectedLength) + : table(table), expectedLength(expectedLength) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + auto stream = response.send(200, "OK", HttpHeaders(table), expectedLength); + auto promise = stream->write("foo", 3); + return promise.attach(kj::mv(stream)).then([this]() { + return kj::mv(paf.promise); + }); + } + + kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } + +private: + HttpHeaderTable& table; + kj::Maybe expectedLength; + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); +}; + +void doDelayedCompletionTest(bool exception, kj::Maybe expectedLength) noexcept { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + + DelayedCompletionHttpService service(table, expectedLength); + auto client = newHttpClient(service); + + auto resp = client->request(HttpMethod::GET, "/", HttpHeaders(table), uint64_t(0)) + .response.wait(waitScope); + KJ_EXPECT(resp.statusCode == 200); + + // Read "foo" from the response body: works + char buffer[16]; + KJ_ASSERT(resp.body->tryRead(buffer, 1, sizeof(buffer)).wait(waitScope) == 3); + buffer[3] = '\0'; + KJ_EXPECT(buffer == "foo"_kj); + + // But reading any more hangs. + auto promise = resp.body->tryRead(buffer, 1, sizeof(buffer)); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Until we cause the service to return. + if (exception) { + service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); + } else { + service.getFulfiller().fulfill(); + } + + KJ_ASSERT(promise.poll(waitScope)); + + if (exception) { + KJ_EXPECT_THROW_MESSAGE("service-side failure", promise.wait(waitScope)); + } else { + promise.wait(waitScope); + } +}; + +KJ_TEST("adapted client waits for service to complete before returning EOF on response stream") { + doDelayedCompletionTest(false, uint64_t(3)); +} + +KJ_TEST("adapted client waits for service to complete before returning EOF on chunked response") { + doDelayedCompletionTest(false, nullptr); +} + +KJ_TEST("adapted client propagates throw from service after complete response body sent") { + doDelayedCompletionTest(true, uint64_t(3)); +} + +KJ_TEST("adapted client propagates throw from service after incomplete response body sent") { + doDelayedCompletionTest(true, uint64_t(6)); +} + +KJ_TEST("adapted client propagates throw from service after chunked response body sent") { + doDelayedCompletionTest(true, nullptr); +} + +class DelayedCompletionWebSocketHttpService final: public HttpService { +public: + DelayedCompletionWebSocketHttpService(HttpHeaderTable& table, bool closeUpstreamFirst) + : table(table), closeUpstreamFirst(closeUpstreamFirst) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_ASSERT(headers.isWebSocket()); + + auto ws = response.acceptWebSocket(HttpHeaders(table)); + kj::Promise promise = kj::READY_NOW; + if (closeUpstreamFirst) { + // Wait for a close message from the client before starting. + promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); + } + promise = promise + .then([&ws = *ws]() { return ws.send("foo"_kj); }) + .then([&ws = *ws]() { return ws.close(1234, "closed"_kj); }); + if (!closeUpstreamFirst) { + // Wait for a close message from the client at the end. + promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); + } + return promise.attach(kj::mv(ws)).then([this]() { + return kj::mv(paf.promise); + }); + } + + kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } + +private: + HttpHeaderTable& table; + bool closeUpstreamFirst; + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); +}; + +void doDelayedCompletionWebSocketTest(bool exception, bool closeUpstreamFirst) noexcept { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + + DelayedCompletionWebSocketHttpService service(table, closeUpstreamFirst); + auto client = newHttpClient(service); + + auto resp = client->openWebSocket("/", HttpHeaders(table)).wait(waitScope); + auto ws = kj::mv(KJ_ASSERT_NONNULL(resp.webSocketOrBody.tryGet>())); + + if (closeUpstreamFirst) { + // Send "close" immediately. + ws->close(1234, "whatever"_kj).wait(waitScope); + } + + // Read "foo" from the WebSocket: works + { + auto msg = ws->receive().wait(waitScope); + KJ_ASSERT(msg.is()); + KJ_ASSERT(msg.get() == "foo"); + } + + kj::Promise promise = nullptr; + if (closeUpstreamFirst) { + // Receiving the close hangs. + promise = ws->receive() + .then([](WebSocket::Message&& msg) { KJ_EXPECT(msg.is()); }); + } else { + auto msg = ws->receive().wait(waitScope); + KJ_ASSERT(msg.is()); + + // Sending a close hangs. + promise = ws->close(1234, "whatever"_kj); + } + KJ_EXPECT(!promise.poll(waitScope)); + + // Until we cause the service to return. + if (exception) { + service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); + } else { + service.getFulfiller().fulfill(); + } + + KJ_ASSERT(promise.poll(waitScope)); + + if (exception) { + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("service-side failure", promise.wait(waitScope)); + } else { + promise.wait(waitScope); + } +}; + +KJ_TEST("adapted client waits for service to complete before completing upstream close on WebSocket") { + doDelayedCompletionWebSocketTest(false, false); +} + +KJ_TEST("adapted client waits for service to complete before returning downstream close on WebSocket") { + doDelayedCompletionWebSocketTest(false, true); +} + +KJ_TEST("adapted client propagates throw from service after WebSocket upstream close sent") { + doDelayedCompletionWebSocketTest(true, false); +} + +KJ_TEST("adapted client propagates throw from service after WebSocket downstream close sent") { + doDelayedCompletionWebSocketTest(true, true); +} + +// ----------------------------------------------------------------------------- + +class CountingIoStream final: public kj::AsyncIoStream { + // An AsyncIoStream wrapper which decrements a counter when destroyed (allowing us to count how + // many connections are open). + +public: + CountingIoStream(kj::Own inner, uint& count) + : inner(kj::mv(inner)), count(count) {} + ~CountingIoStream() noexcept(false) { + --count; + } + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->read(buffer, minBytes, maxBytes); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + kj::Maybe tryGetLength() override { + return inner->tryGetLength();; + } + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return inner->pumpTo(output, amount); + } + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + void shutdownWrite() override { + return inner->shutdownWrite(); + } + void abortRead() override { + return inner->abortRead(); + } + +public: + kj::Own inner; + uint& count; +}; + +class CountingNetworkAddress final: public kj::NetworkAddress { +public: + CountingNetworkAddress(kj::NetworkAddress& inner, uint& count, uint& cumulative) + : inner(inner), count(count), addrCount(ownAddrCount), cumulative(cumulative) {} + CountingNetworkAddress(kj::Own inner, uint& count, uint& addrCount) + : inner(*inner), ownInner(kj::mv(inner)), count(count), addrCount(addrCount), + cumulative(ownCumulative) {} + ~CountingNetworkAddress() noexcept(false) { + --addrCount; + } + + kj::Promise> connect() override { + ++count; + ++cumulative; + return inner.connect() + .then([this](kj::Own stream) -> kj::Own { + return kj::heap(kj::mv(stream), count); + }); + } + + kj::Own listen() override { KJ_UNIMPLEMENTED("test"); } + kj::Own clone() override { KJ_UNIMPLEMENTED("test"); } + kj::String toString() override { KJ_UNIMPLEMENTED("test"); } + +private: + kj::NetworkAddress& inner; + kj::Own ownInner; + uint& count; + uint ownAddrCount = 1; + uint& addrCount; + uint ownCumulative = 0; + uint& cumulative; +}; + +class ConnectionCountingNetwork final: public kj::Network { +public: + ConnectionCountingNetwork(kj::Network& inner, uint& count, uint& addrCount) + : inner(inner), count(count), addrCount(addrCount) {} + + Promise> parseAddress(StringPtr addr, uint portHint = 0) override { + ++addrCount; + return inner.parseAddress(addr, portHint) + .then([this](Own&& addr) -> Own { + return kj::heap(kj::mv(addr), count, addrCount); + }); + } + Own getSockaddr(const void* sockaddr, uint len) override { + KJ_UNIMPLEMENTED("test"); + } + Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) override { + KJ_UNIMPLEMENTED("test"); + } + +private: + kj::Network& inner; + uint& count; + uint& addrCount; +}; + +class DummyService final: public HttpService { +public: + DummyService(HttpHeaderTable& headerTable): headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + if (!headers.isWebSocket()) { + if (url == "/throw") { + return KJ_EXCEPTION(FAILED, "client requested failure"); + } + + auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); + auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size()); + auto promises = kj::heapArrayBuilder>(2); + promises.add(stream->write(body.begin(), body.size())); + promises.add(requestBody.readAllBytes().ignoreResult()); + return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); + } else { + auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); + auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); + auto sendPromise = ws->send(body); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(sendPromise.attach(kj::mv(body))); + promises.add(ws->receive().ignoreResult()); + return kj::joinPromises(promises.finish()).attach(kj::mv(ws)); + } + } + +private: + HttpHeaderTable& headerTable; +}; + +KJ_TEST("HttpClient connection management") { + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // We can do several requests in a row and only have one connection. + doRequest().wait(waitScope); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 1); + + // But if we do two in parallel, we'll end up with two connections. + auto req1 = doRequest(); + auto req2 = doRequest(); + req1.wait(waitScope); + req2.wait(waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // We can reuse after a POST, provided we write the whole POST body properly. + { + auto req = client->request( + HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)); + req.body->write("foobar", 6).wait(waitScope); + req.response.wait(waitScope).body->readAllBytes().wait(waitScope); + } + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + doRequest().wait(waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // Advance time for half the timeout, then exercise one of the connections. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + waitScope.poll(); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // Advance time past when the other connection should time out. It should be dropped. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 3 / 4); + waitScope.poll(); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 2); + + // Wait for the other to drop. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); + waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 2); + + // New request creates a new connection again. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 3); + + // WebSocket connections are not reused. + client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable)) + .wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Errored connections are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 4); + client->request(HttpMethod::GET, kj::str("/throw"), HttpHeaders(headerTable)).response + .wait(waitScope).body->readAllBytes().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 4); + + // Connections where we failed to read the full response body are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 5); + client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)).response + .wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); + + // Connections where we didn't even wait for the response headers are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 6); + client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 6); + + // Connections where we failed to write the full request body are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 7); + client->request(HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)).response + .wait(waitScope).body->readAllBytes().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 7); + + // If the server times out the connection, we figure it out on the client. + doRequest().wait(waitScope); + + // TODO(someday): Figure out why the following poll is necessary for the test to pass on Windows + // and Mac. Without it, it seems that the request's connection never starts, so the + // subsequent advanceTo() does not actually time out the connection. + waitScope.poll(); + + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 8); + serverTimer.advanceTo(serverTimer.now() + serverSettings.pipelineTimeout * 2); + waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 8); + + // Can still make requests. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 9); +} + +KJ_TEST("HttpClient disable connection reuse") { + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.idleTimeout = 0 * kj::SECONDS; + auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // Each serial request gets its own connection. + doRequest().wait(waitScope); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Each parallel request gets its own connection. + auto req1 = doRequest(); + auto req2 = doRequest(); + req1.wait(waitScope); + req2.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); +} + +KJ_TEST("HttpClient concurrency limiting") { +#if KJ_HTTP_TEST_USE_OS_PIPE && !__linux__ + // On Windows and Mac, OS event delivery is not always immediate, and that seems to make this + // test flakey. On Linux, events are always immediately delivered. For now, we compile the test + // but we don't run it outside of Linux. We do run the in-memory-pipes version on all OSs since + // that mode shouldn't depend on kernel behavior at all. + return; +#endif + + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.idleTimeout = 0 * kj::SECONDS; + auto innerClient = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + struct CallbackEvent { + uint runningCount; + uint pendingCount; + + bool operator==(const CallbackEvent& other) const { + return runningCount == other.runningCount && pendingCount == other.pendingCount; + } + bool operator!=(const CallbackEvent& other) const { return !(*this == other); } + // TODO(someday): Can use default spaceship operator in C++20: + //auto operator<=>(const CallbackEvent&) const = default; + }; + + kj::Vector callbackEvents; + auto callback = [&](uint runningCount, uint pendingCount) { + callbackEvents.add(CallbackEvent{runningCount, pendingCount}); + }; + auto client = newConcurrencyLimitingHttpClient(*innerClient, 1, kj::mv(callback)); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // Second connection blocked by first. + auto req1 = doRequest(); + + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + + auto req2 = doRequest(); + + // TODO(someday): Figure out why this poll() is necessary on Windows and macOS. + waitScope.poll(); + + KJ_EXPECT(req1.poll(waitScope)); + KJ_EXPECT(!req2.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 1); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); + callbackEvents.clear(); + + // Releasing first connection allows second to start. + req1.wait(waitScope); + KJ_EXPECT(req2.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 2); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + + req2.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 2); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); + callbackEvents.clear(); + + // Using body stream after releasing blocked response promise throws no exception + auto req3 = doRequest(); + { + kj::Own req4Body; + { + auto req4 = client->request(HttpMethod::GET, kj::str("/", ++i), HttpHeaders(headerTable)); + waitScope.poll(); + req4Body = kj::mv(req4.body); + } + auto writePromise = req4Body->write("a", 1); + KJ_EXPECT(!writePromise.poll(waitScope)); + } + req3.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Similar connection limiting for web sockets + // TODO(someday): Figure out why the sequencing of websockets events does + // not work correctly on Windows (and maybe macOS?). The solution is not as + // simple as inserting poll()s as above, since doing so puts the websocket in + // a state that trips a "previous HTTP message body incomplete" assertion, + // while trying to write 500 network response. + callbackEvents.clear(); + auto ws1 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + auto ws2 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); + KJ_EXPECT(ws1->poll(waitScope)); + KJ_EXPECT(!ws2->poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 4); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); + callbackEvents.clear(); + + { + auto response1 = ws1->wait(waitScope); + KJ_EXPECT(!ws2->poll(waitScope)); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); + } + KJ_EXPECT(ws2->poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 5); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + { + auto response2 = ws2->wait(waitScope); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); + } + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); +} + +#if KJ_HTTP_TEST_USE_OS_PIPE +// This test relies on access to the network. +KJ_TEST("NetworkHttpClient connect impl") { + KJ_HTTP_TEST_SETUP_IO; + auto listener1 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + + auto ignored KJ_UNUSED = listener1->accept().then([](Own stream) { + auto buffer = kj::str("test"); + return stream->write(buffer.cStr(), buffer.size()).attach(kj::mv(stream), kj::mv(buffer)); + }).eagerlyEvaluate(nullptr); + + HttpClientSettings clientSettings; + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + auto client = newHttpClient(clientTimer, headerTable, + io.provider->getNetwork(), nullptr, clientSettings); + auto request = client->connect( + kj::str("localhost:", listener1->getPort()), HttpHeaders(headerTable), {}); + + auto buf = kj::heapArray(4); + return request.connection->tryRead(buf.begin(), 1, buf.size()) + .then([buf = kj::mv(buf)](size_t count) { + KJ_ASSERT(count == 4); + KJ_ASSERT(kj::str(buf.asChars()) == "test"); + }).attach(kj::mv(request.connection)).wait(io.waitScope); +} +#endif + +#if KJ_HTTP_TEST_USE_OS_PIPE +// TODO(someday): Implement mock kj::Network for userspace version of this test? +KJ_TEST("HttpClient multi host") { + auto io = kj::setupAsyncIo(); + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + auto listener1 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + auto listener2 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + DummyService service(headerTable); + HttpServer server(serverTimer, headerTable, service); + auto listenTask1 = server.listenHttp(*listener1); + auto listenTask2 = server.listenHttp(*listener2); + + uint count = 0, addrCount = 0; + uint tlsCount = 0, tlsAddrCount = 0; + ConnectionCountingNetwork countingNetwork(io.provider->getNetwork(), count, addrCount); + ConnectionCountingNetwork countingTlsNetwork(io.provider->getNetwork(), tlsCount, tlsAddrCount); + + HttpClientSettings clientSettings; + auto client = newHttpClient(clientTimer, headerTable, + countingNetwork, countingTlsNetwork, clientSettings); + + KJ_EXPECT(count == 0); + + uint i = 0; + auto doRequest = [&](bool tls, uint port) { + uint n = i++; + // We stick a double-slash in the URL to test that it doesn't get coalesced into one slash, + // which was a bug in the past. + return client->request(HttpMethod::GET, + kj::str((tls ? "https://localhost:" : "http://localhost:"), port, "//", n), + HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n, port](kj::String body) { + KJ_EXPECT(body == kj::str("localhost:", port, "://", n), body, port, n); + }); + }; + + uint port1 = listener1->getPort(); + uint port2 = listener2->getPort(); + + // We can do several requests in a row to the same host and only have one connection. + doRequest(false, port1).wait(io.waitScope); + doRequest(false, port1).wait(io.waitScope); + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 1); + KJ_EXPECT(tlsAddrCount == 0); + + // Request a different host, and now we have two connections. + doRequest(false, port2).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 0); + + // Try TLS. + doRequest(true, port1).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Try first host again, no change in connection count. + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Multiple requests in parallel forces more connections to that host. + auto promise1 = doRequest(false, port1); + auto promise2 = doRequest(false, port1); + promise1.wait(io.waitScope); + promise2.wait(io.waitScope); + KJ_EXPECT(count == 3); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Let everything expire. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 2); + io.waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 0); + KJ_EXPECT(tlsAddrCount == 0); + + // We can still request those hosts again. + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 1); + KJ_EXPECT(tlsAddrCount == 0); +} +#endif + +// ----------------------------------------------------------------------------- + +#if KJ_HTTP_TEST_USE_OS_PIPE +// This test only makes sense using the real network. +KJ_TEST("HttpClient to capnproto.org") { + auto io = kj::setupAsyncIo(); + + auto maybeConn = io.provider->getNetwork().parseAddress("capnproto.org", 80) + .then([](kj::Own addr) { + auto promise = addr->connect(); + return promise.attach(kj::mv(addr)); + }).then([](kj::Own&& connection) -> kj::Maybe> { + return kj::mv(connection); + }, [](kj::Exception&& e) -> kj::Maybe> { + KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); + return nullptr; + }).wait(io.waitScope); + + KJ_IF_MAYBE(conn, maybeConn) { + // Successfully connected to capnproto.org. Try doing GET /. We expect to get a redirect to + // HTTPS, because what kind of horrible web site would serve in plaintext, really? + + HttpHeaderTable table; + auto client = newHttpClient(table, **conn); + + HttpHeaders headers(table); + headers.set(HttpHeaderId::HOST, "capnproto.org"); + + auto response = client->request(HttpMethod::GET, "/", headers).response.wait(io.waitScope); + KJ_EXPECT(response.statusCode / 100 == 3); + auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); + KJ_EXPECT(location == "https://capnproto.org/"); + + auto body = response.body->readAllText().wait(io.waitScope); + } +} +#endif + +// ======================================================================================= +// Misc bugfix tests + +class ReadCancelHttpService final: public HttpService { + // HttpService that tries to read all request data but cancels after 1ms and sends a response. +public: + ReadCancelHttpService(kj::Timer& timer, HttpHeaderTable& headerTable) + : timer(timer), headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + if (method == HttpMethod::POST) { + // Try to read all content, but cancel after 1ms. + + // Actually, we can't literally cancel mid-read, because this leaves the stream in an + // unknown state which requires closing the connection. Instead, we know that the sender + // will send 5 bytes, so we read that, then pause. + static char junk[5]; + return requestBody.read(junk, 5) + .then([]() -> kj::Promise { return kj::NEVER_DONE; }) + .exclusiveJoin(timer.afterDelay(1 * kj::MILLISECONDS)) + .then([this, &responseSender]() { + responseSender.send(408, "Request Timeout", kj::HttpHeaders(headerTable), uint64_t(0)); + }); + } else { + responseSender.send(200, "OK", kj::HttpHeaders(headerTable), uint64_t(0)); + return kj::READY_NOW; + } + } + +private: + kj::Timer& timer; + HttpHeaderTable& headerTable; +}; + +KJ_TEST("canceling a length stream mid-read correctly discards rest of request") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + ReadCancelHttpService service(timer, table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + { + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "fooba"_kj; // incomplete + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 408 Request Timeout\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Trigger timeout, then response should be sent. + timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } + + // We left our request stream hanging. The server will try to read and discard the request body. + // Let's give it the rest of the data, followed by a second request. + { + static constexpr kj::StringPtr REQUEST = + "r" + "GET / HTTP/1.1\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } +} + +KJ_TEST("canceling a chunked stream mid-read correctly discards rest of request") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + ReadCancelHttpService service(timer, table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + { + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "fooba"_kj; // incomplete chunk + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 408 Request Timeout\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Trigger timeout, then response should be sent. + timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } + + // We left our request stream hanging. The server will try to read and discard the request body. + // Let's give it the rest of the data, followed by a second request. + { + static constexpr kj::StringPtr REQUEST = + "r\r\n" + "4a\r\n" + "this is some text that is the body of a chunk and not a valid chunk header\r\n" + "0\r\n" + "\r\n" + "GET / HTTP/1.1\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } +} + +KJ_TEST("drain() doesn't lose bytes when called at the wrong moment") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttpCleanDrain(*pipe.ends[0]); + + // Do a regular request. + static constexpr kj::StringPtr REQUEST = + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 13\r\n" + "\r\n" + "example.com:/"_kj).wait(waitScope); + + // Make sure the server is blocked on the next read from the socket. + kj::Promise(kj::NEVER_DONE).poll(waitScope); + + // Now simultaneously deliver a new request AND drain the socket. + auto drainPromise = server.drain(); + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + +#if KJ_HTTP_TEST_USE_OS_PIPE + // In the case of an OS pipe, the drain will complete before any data is read from the socket. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); + + // The request will not have been read off the socket. We can read it now. + pipe.ends[1]->shutdownWrite(); + KJ_EXPECT(pipe.ends[0]->readAllText().wait(waitScope) == REQUEST2); + +#else + // In the case of an in-memory pipe, the write() will have delivered bytes directly to the + // destination buffer synchronously, which means that the server must handle the request + // before draining. + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // The HTTP request should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // Now the drain completes. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); +#endif +} + +KJ_TEST("drain() does not cancel the first request on a new connection") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttpCleanDrain(*pipe.ends[0]); + + // Request a drain(). It won't complete, because the newly-connected socket is considered to have + // an in-flight request. + auto drainPromise = server.drain(); + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // Deliver the request. + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + + // It should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // Now the drain completes. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); +} + +KJ_TEST("drain() when NOT using listenHttpCleanDrain() sends Connection: close header") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Request a drain(). It won't complete, because the newly-connected socket is considered to have + // an in-flight request. + auto drainPromise = server.drain(); + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // Deliver the request. + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + + // It should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Connection: close\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // And then EOF. + auto rest = pipe.ends[1]->readAllText(); + KJ_ASSERT(rest.poll(waitScope)); + KJ_EXPECT(rest.wait(waitScope) == nullptr); + + // The drain task and listen task are done. + drainPromise.wait(waitScope); + listenTask.wait(waitScope); +} + +class BrokenConnectionListener final: public kj::ConnectionReceiver { +public: + void fulfillOne(kj::Own stream) { + fulfiller->fulfill(kj::mv(stream)); + } + + kj::Promise> accept() override { + auto paf = kj::newPromiseAndFulfiller>(); + fulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + uint getPort() override { + KJ_UNIMPLEMENTED("not used"); + } + +private: + kj::Own>> fulfiller; +}; + +class BrokenConnection final: public kj::AsyncIoStream { +public: + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise write(const void* buffer, size_t size) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise write(ArrayPtr> pieces) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + void shutdownWrite() override {} +}; + +KJ_TEST("HttpServer.listenHttp() doesn't prematurely terminate if an accepted connection is broken") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + BrokenConnectionListener listener; + auto promise = server.listenHttp(listener).eagerlyEvaluate(nullptr); + + // Loop is waiting for a connection. + KJ_ASSERT(!promise.poll(waitScope)); + + KJ_EXPECT_LOG(ERROR, "failed: broken"); + listener.fulfillOne(kj::heap()); + + // The loop should not have stopped, even though the connection was broken. + KJ_ASSERT(!promise.poll(waitScope)); +} + +KJ_TEST("HttpServer handles disconnected exception for clients disconnecting after headers") { + // This test case reproduces a race condition where a client could disconnect after the server + // sent response headers but before it sent the response body, resulting in a broken pipe + // "disconnected" exception when writing the body. The default handler for application errors + // tells the server to ignore "disconnected" exceptions and close the connection, but code + // after the handler exercised the broken connection, causing the server loop to instead fail + // with a "failed" exception. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + class SendErrorHttpService final: public HttpService { + // HttpService that serves an error page via sendError(). + public: + SendErrorHttpService(HttpHeaderTable& headerTable): headerTable(headerTable) {} + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + return responseSender.sendError(404, "Not Found", headerTable); + } + + private: + HttpHeaderTable& headerTable; + }; + + class DisconnectingAsyncIoStream final: public kj::AsyncIoStream { + public: + DisconnectingAsyncIoStream(AsyncIoStream& inner): inner(inner) {} + + Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.read(buffer, minBytes, maxBytes); + } + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.tryRead(buffer, minBytes, maxBytes); + } + + Maybe tryGetLength() override { return inner.tryGetLength(); } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return inner.pumpTo(output, amount); + } + + Promise write(const void* buffer, size_t size) override { + int writeId = writeCount++; + if (writeId == 0) { + // Allow first write (headers) to succeed. + auto promise = inner.write(buffer, size); + inner.shutdownWrite(); + return promise; + } else if (writeId == 1) { + // Fail subsequent write (body) with a disconnected exception. + return KJ_EXCEPTION(DISCONNECTED, "a_disconnected_exception"); + } else { + KJ_FAIL_ASSERT("Unexpected write"); + } + } + Promise write(ArrayPtr> pieces) override { + return inner.write(pieces); + } + + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + return inner.tryPumpFrom(input, amount); + } + + Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + + void shutdownWrite() override { + return inner.shutdownWrite(); + } + + void abortRead() override { + return inner.abortRead(); + } + + void getsockopt(int level, int option, void* value, uint* length) override { + return inner.getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, uint length) override { + return inner.setsockopt(level, option, value, length); + } + + void getsockname(struct sockaddr* addr, uint* length) override { + return inner.getsockname(addr, length); + } + void getpeername(struct sockaddr* addr, uint* length) override { + return inner.getsockname(addr, length); + } + + int writeCount = 0; + + private: + kj::AsyncIoStream& inner; + }; + + class TestErrorHandler: public HttpServerErrorHandler { + public: + kj::Promise handleApplicationError( + kj::Exception exception, kj::Maybe response) override { + applicationErrorCount++; + if (exception.getType() == kj::Exception::Type::DISCONNECTED) { + // Tell HttpServer to ignore disconnected exceptions (the default behavior). + return kj::READY_NOW; + } + KJ_FAIL_ASSERT("Unexpected application error type", exception.getType()); + } + + int applicationErrorCount = 0; + }; + + TestErrorHandler testErrorHandler; + HttpServerSettings settings {}; + settings.errorHandler = testErrorHandler; + + HttpHeaderTable table; + SendErrorHttpService service(table); + HttpServer server(timer, table, service, settings); + + auto stream = kj::heap(*pipe.ends[0]); + auto listenPromise = server.listenHttpCleanDrain(*stream); + + static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj; + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + // Client races to read headers but not body, then disconnects. (Note that the following code + // doesn't reliably reproduce the race condition by itself -- DisconnectingAsyncIoStream is + // needed to ensure the disconnected exception throws on the correct write promise.) + expectRead(*pipe.ends[1], + "HTTP/1.1 404 Not Found\r\n" + "Content-Length: 9\r\n" + "\r\n"_kj).wait(waitScope); + pipe.ends[1] = nullptr; + + // The race condition failure would manifest as a "previous HTTP message body incomplete" + // "FAILED" exception here: + bool canReuse = listenPromise.wait(waitScope); + + KJ_ASSERT(!canReuse); + KJ_ASSERT(stream->writeCount == 2); + KJ_ASSERT(testErrorHandler.applicationErrorCount == 1); +} + +// ======================================================================================= +// CONNECT tests + +class ConnectEchoService final: public HttpService { + // A simple CONNECT echo. It will always accept, and whatever data it + // receives will be echoed back. +public: + ConnectEchoService(HttpHeaderTable& headerTable, uint statusCodeToSend = 200) + : headerTable(headerTable), + statusCodeToSend(statusCodeToSend) { + KJ_ASSERT(statusCodeToSend >= 200 && statusCodeToSend < 300); + } + + uint connectCount = 0; + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + connectCount++; + response.accept(statusCodeToSend, "OK", HttpHeaders(headerTable)); + return connection.pumpTo(connection).ignoreResult(); + } + +private: + HttpHeaderTable& headerTable; + uint statusCodeToSend; +}; + +class ConnectRejectService final: public HttpService { + // A simple CONNECT implementation that always rejects. +public: + ConnectRejectService(HttpHeaderTable& headerTable, uint statusCodeToSend = 400) + : headerTable(headerTable), + statusCodeToSend(statusCodeToSend) { + KJ_ASSERT(statusCodeToSend >= 300); + } + + uint connectCount = 0; + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + connectCount++; + auto out = response.reject(statusCodeToSend, "Failed"_kj, HttpHeaders(headerTable), 4); + return out->write("boom", 4).attach(kj::mv(out)); + } + +private: + HttpHeaderTable& headerTable; + uint statusCodeToSend; +}; + +class ConnectCancelReadService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // cancel reading from it to test handling of abrupt termination. +public: + ConnectCancelReadService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + // Return an immediately resolved promise and drop the connection + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +class ConnectCancelWriteService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // cancel writing to it to test handling of abrupt termination. +public: + ConnectCancelWriteService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + + auto msg = "hello"_kj; + auto promise KJ_UNUSED = connection.write(msg.begin(), 5); + + // Return an immediately resolved promise and drop the io + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +class ConnectHttpService final: public HttpService { + // A CONNECT service that tunnels HTTP requests just to verify that, yes, the CONNECT + // impl can actually tunnel actual protocols. +public: + ConnectHttpService(HttpHeaderTable& table) + : timer(kj::origin()), + tunneledService(table), + server(timer, table, tunneledService) {} +private: + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(tunneledService.table)); + return server.listenHttp(kj::Own(&connection, kj::NullDisposer::instance)); + } + + class SimpleHttpService final: public HttpService { + public: + SimpleHttpService(HttpHeaderTable& table) : table(table) {} + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + auto out = response.send(200, "OK"_kj, HttpHeaders(table)); + auto msg = "hello there"_kj; + return out->write(msg.begin(), 11).attach(kj::mv(out)); + } + + HttpHeaderTable& table; + }; + + kj::TimerImpl timer; + SimpleHttpService tunneledService; + HttpServer server; +}; + +class ConnectCloseService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // shutdown the write side of the AsyncIoStream to simulate socket disconnection. +public: + ConnectCloseService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + connection.shutdownWrite(); + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +KJ_TEST("Simple CONNECT Server works") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n" + "hello"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("Simple CONNECT Client/Server works") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + // Initiates a CONNECT with the echo server. Once established, sends a bit of data + // and waits for it to be echoed back. + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(io->write("hello", 5)); + promises.add(expectRead(*io, "hello"_kj)); + return kj::joinPromises(promises.finish()) + .then([io=kj::mv(io)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("CONNECT Server (201 status)") { + KJ_HTTP_TEST_SETUP_IO; + + // Test that CONNECT works with 2xx status codes that typically do + // not carry a response payload. + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table, 201); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 201 OK\r\n" + "\r\n" + "hello"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("CONNECT Client (204 status)") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + // Test that CONNECT works with 2xx status codes that typically do + // not carry a response payload. + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table, 204); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + // Initiates a CONNECT with the echo server. Once established, sends a bit of data + // and waits for it to be echoed back. + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 204); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(io->write("hello", 5)); + promises.add(expectRead(*io, "hello"_kj)); + + return kj::joinPromises(promises.finish()) + .then([io=kj::mv(io)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("CONNECT Server rejected") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectRejectService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Failed\r\n" + "Connection: close\r\n" + "Content-Length: 4\r\n" + "\r\n" + "boom"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Client rejected") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectRejectService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([](auto status) mutable { + KJ_ASSERT(status.statusCode == 400); + KJ_ASSERT(status.statusText == "Failed"_kj); + + auto& errorBody = KJ_ASSERT_NONNULL(status.errorBody); + + return expectRead(*errorBody, "boom"_kj).then([&errorBody=*errorBody]() { + return expectEnd(errorBody); + }).attach(kj::mv(errorBody)); + }).wait(waitScope); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} +#endif + +KJ_TEST("CONNECT Server cancels read") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectCancelReadService service(table); + HttpServer server(timer, table, service); - // Connections where we failed to write the full request body are not reused. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 7); - client->request(HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)).response - .wait(waitScope).body->readAllBytes().wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 7); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - // If the server times out the connection, we figure it out on the client. - doRequest().wait(waitScope); + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; - // TODO(someday): Figure out why the following poll is necessary for the test to pass on Windows - // and Mac. Without it, it seems that the request's connection never starts, so the - // subsequent advanceTo() does not actually time out the connection. - waitScope.poll(); + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 8); - serverTimer.advanceTo(serverTimer.now() + serverSettings.pipelineTimeout * 2); - waitScope.poll(); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 8); + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n"_kj).wait(waitScope); - // Can still make requests. - doRequest().wait(waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 9); + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); } -KJ_TEST("HttpClient disable connection reuse") { +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Server cancels read w/client") { KJ_HTTP_TEST_SETUP_IO; - KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; - kj::TimerImpl serverTimer(kj::origin()); - kj::TimerImpl clientTimer(kj::origin()); - HttpHeaderTable headerTable; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - DummyService service(headerTable); - HttpServerSettings serverSettings; - HttpServer server(serverTimer, headerTable, service, serverSettings); - auto listenTask = server.listenHttp(*listener); + kj::TimerImpl timer(kj::origin()); - uint count = 0; - uint cumulative = 0; - CountingNetworkAddress countingAddr(*addr, count, cumulative); + HttpHeaderTable table; + ConnectCancelReadService service(table); + HttpServer server(timer, table, service); - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; - clientSettings.idleTimeout = 0 * kj::SECONDS; - auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 0); + auto client = newHttpClient(table, *pipe.ends[1]); + bool failed = false; - uint i = 0; - auto doRequest = [&]() { - uint n = i++; - return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response - .then([](HttpClient::Response&& response) { - auto promise = response.body->readAllText(); - return promise.attach(kj::mv(response.body)); - }).then([n](kj::String body) { - KJ_EXPECT(body == kj::str("null:/", n)); - }); - }; + HttpHeaderTable clientHeaders; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); - // Each serial request gets its own connection. - doRequest().wait(waitScope); - doRequest().wait(waitScope); - doRequest().wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 3); + request.status.then([&failed, io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); - // Each parallel request gets its own connection. - auto req1 = doRequest(); - auto req2 = doRequest(); - req1.wait(waitScope); - req2.wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 5); -} + return io->write("hello", 5).catch_([&](kj::Exception&& ex) { + KJ_ASSERT(ex.getType() == kj::Exception::Type::DISCONNECTED); + failed = true; + }).attach(kj::mv(io)); + }).wait(waitScope); -KJ_TEST("HttpClient concurrency limiting") { -#if KJ_HTTP_TEST_USE_OS_PIPE && !__linux__ - // On Windows and Mac, OS event delivery is not always immediate, and that seems to make this - // test flakey. On Linux, events are always immediately delivered. For now, we compile the test - // but we don't run it outside of Linux. We do run the in-memory-pipes version on all OSs since - // that mode shouldn't depend on kernel behavior at all. - return; + KJ_ASSERT(failed, "the write promise should have failed"); + + listenTask.wait(waitScope); +} #endif +KJ_TEST("CONNECT Server cancels write") { KJ_HTTP_TEST_SETUP_IO; - KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; - kj::TimerImpl serverTimer(kj::origin()); - kj::TimerImpl clientTimer(kj::origin()); - HttpHeaderTable headerTable; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - DummyService service(headerTable); - HttpServerSettings serverSettings; - HttpServer server(serverTimer, headerTable, service, serverSettings); - auto listenTask = server.listenHttp(*listener); + kj::TimerImpl timer(kj::origin()); - uint count = 0; - uint cumulative = 0; - CountingNetworkAddress countingAddr(*addr, count, cumulative); + HttpHeaderTable table; + ConnectCancelWriteService service(table); + HttpServer server(timer, table, service); - FakeEntropySource entropySource; - HttpClientSettings clientSettings; - clientSettings.entropySource = entropySource; - clientSettings.idleTimeout = 0 * kj::SECONDS; - auto innerClient = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - struct CallbackEvent { - uint runningCount; - uint pendingCount; + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; - bool operator==(const CallbackEvent& other) const { - return runningCount == other.runningCount && pendingCount == other.pendingCount; - } - bool operator!=(const CallbackEvent& other) const { return !(*this == other); } - // TODO(someday): Can use default spaceship operator in C++20: - //auto operator<=>(const CallbackEvent&) const = default; - }; + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); - kj::Vector callbackEvents; - auto callback = [&](uint runningCount, uint pendingCount) { - callbackEvents.add(CallbackEvent{runningCount, pendingCount}); - }; - auto client = newConcurrencyLimitingHttpClient(*innerClient, 1, kj::mv(callback)); + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n"_kj).wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 0); + expectEnd(*pipe.ends[1]); - uint i = 0; - auto doRequest = [&]() { - uint n = i++; - return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response - .then([](HttpClient::Response&& response) { - auto promise = response.body->readAllText(); - return promise.attach(kj::mv(response.body)); - }).then([n](kj::String body) { - KJ_EXPECT(body == kj::str("null:/", n)); - }); - }; + listenTask.wait(waitScope); +} - // Second connection blocked by first. - auto req1 = doRequest(); +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Server cancels write w/client") { + KJ_HTTP_TEST_SETUP_IO; - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); - callbackEvents.clear(); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto req2 = doRequest(); + kj::TimerImpl timer(kj::origin()); - // TODO(someday): Figure out why this poll() is necessary on Windows and macOS. - waitScope.poll(); + HttpHeaderTable table; + ConnectCancelWriteService service(table); + HttpServer server(timer, table, service); - KJ_EXPECT(req1.poll(waitScope)); - KJ_EXPECT(!req2.poll(waitScope)); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 1); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); - callbackEvents.clear(); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - // Releasing first connection allows second to start. - req1.wait(waitScope); - KJ_EXPECT(req2.poll(waitScope)); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 2); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); - callbackEvents.clear(); + auto client = newHttpClient(table, *pipe.ends[1]); - req2.wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 2); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); - callbackEvents.clear(); + HttpHeaderTable clientHeaders; + bool failed = false; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); - // Using body stream after releasing blocked response promise throws no exception - auto req3 = doRequest(); - { - kj::Own req4Body; - { - auto req4 = client->request(HttpMethod::GET, kj::str("/", ++i), HttpHeaders(headerTable)); - waitScope.poll(); - req4Body = kj::mv(req4.body); - } - auto writePromise = req4Body->write("a", 1); - KJ_EXPECT(!writePromise.poll(waitScope)); - } - req3.wait(waitScope); - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 3); + request.status.then([&failed, io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); - // Similar connection limiting for web sockets - // TODO(someday): Figure out why the sequencing of websockets events does - // not work correctly on Windows (and maybe macOS?). The solution is not as - // simple as inserting poll()s as above, since doing so puts the websocket in - // a state that trips a "previous HTTP message body incomplete" assertion, - // while trying to write 500 network response. - callbackEvents.clear(); - auto ws1 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); - callbackEvents.clear(); - auto ws2 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); - KJ_EXPECT(ws1->poll(waitScope)); - KJ_EXPECT(!ws2->poll(waitScope)); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 4); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); - callbackEvents.clear(); + return io->write("hello", 5).catch_([&failed](kj::Exception&& ex) mutable { + KJ_ASSERT(ex.getType() == kj::Exception::Type::DISCONNECTED); + failed = true; + }).attach(kj::mv(io)); + }).wait(waitScope); - { - auto response1 = ws1->wait(waitScope); - KJ_EXPECT(!ws2->poll(waitScope)); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); - } - KJ_EXPECT(ws2->poll(waitScope)); - KJ_EXPECT(count == 1); - KJ_EXPECT(cumulative == 5); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); - callbackEvents.clear(); - { - auto response2 = ws2->wait(waitScope); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); - } - KJ_EXPECT(count == 0); - KJ_EXPECT(cumulative == 5); - KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); + KJ_ASSERT(failed, "the write promise should have failed"); + + listenTask.wait(waitScope); } +#endif -#if KJ_HTTP_TEST_USE_OS_PIPE -// TODO(someday): Implement mock kj::Network for userspace version of this test? -KJ_TEST("HttpClient multi host") { - auto io = kj::setupAsyncIo(); +KJ_TEST("CONNECT rejects Transfer-Encoding") { + KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl serverTimer(kj::origin()); - kj::TimerImpl clientTimer(kj::origin()); - HttpHeaderTable headerTable; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto listener1 = io.provider->getNetwork().parseAddress("localhost", 0) - .wait(io.waitScope)->listen(); - auto listener2 = io.provider->getNetwork().parseAddress("localhost", 0) - .wait(io.waitScope)->listen(); - DummyService service(headerTable); - HttpServer server(serverTimer, headerTable, service); - auto listenTask1 = server.listenHttp(*listener1); - auto listenTask2 = server.listenHttp(*listener2); + kj::TimerImpl timer(kj::origin()); - uint count = 0, addrCount = 0; - uint tlsCount = 0, tlsAddrCount = 0; - ConnectionCountingNetwork countingNetwork(io.provider->getNetwork(), count, addrCount); - ConnectionCountingNetwork countingTlsNetwork(io.provider->getNetwork(), tlsCount, tlsAddrCount); + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "5\r\n" + "hello" + "0\r\n"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 18\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Bad Request"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); - HttpClientSettings clientSettings; - auto client = newHttpClient(clientTimer, headerTable, - countingNetwork, countingTlsNetwork, clientSettings); + listenTask.wait(waitScope); +} - KJ_EXPECT(count == 0); +KJ_TEST("CONNECT rejects Content-Length") { + KJ_HTTP_TEST_SETUP_IO; - uint i = 0; - auto doRequest = [&](bool tls, uint port) { - uint n = i++; - // We stick a double-slash in the URL to test that it doesn't get coalesced into one slash, - // which was a bug in the past. - return client->request(HttpMethod::GET, - kj::str((tls ? "https://localhost:" : "http://localhost:"), port, "//", n), - HttpHeaders(headerTable)).response - .then([](HttpClient::Response&& response) { - auto promise = response.body->readAllText(); - return promise.attach(kj::mv(response.body)); - }).then([n, port](kj::String body) { - KJ_EXPECT(body == kj::str("localhost:", port, "://", n), body, port, n); - }); - }; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - uint port1 = listener1->getPort(); - uint port2 = listener2->getPort(); + kj::TimerImpl timer(kj::origin()); - // We can do several requests in a row to the same host and only have one connection. - doRequest(false, port1).wait(io.waitScope); - doRequest(false, port1).wait(io.waitScope); - doRequest(false, port1).wait(io.waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(tlsCount == 0); - KJ_EXPECT(addrCount == 1); - KJ_EXPECT(tlsAddrCount == 0); + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); - // Request a different host, and now we have two connections. - doRequest(false, port2).wait(io.waitScope); - KJ_EXPECT(count == 2); - KJ_EXPECT(tlsCount == 0); - KJ_EXPECT(addrCount == 2); - KJ_EXPECT(tlsAddrCount == 0); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - // Try TLS. - doRequest(true, port1).wait(io.waitScope); - KJ_EXPECT(count == 2); - KJ_EXPECT(tlsCount == 1); - KJ_EXPECT(addrCount == 2); - KJ_EXPECT(tlsAddrCount == 1); + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "Content-Length: 5\r\n" + "\r\n" + "hello"_kj; - // Try first host again, no change in connection count. - doRequest(false, port1).wait(io.waitScope); - KJ_EXPECT(count == 2); - KJ_EXPECT(tlsCount == 1); - KJ_EXPECT(addrCount == 2); - KJ_EXPECT(tlsAddrCount == 1); + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); - // Multiple requests in parallel forces more connections to that host. - auto promise1 = doRequest(false, port1); - auto promise2 = doRequest(false, port1); - promise1.wait(io.waitScope); - promise2.wait(io.waitScope); - KJ_EXPECT(count == 3); - KJ_EXPECT(tlsCount == 1); - KJ_EXPECT(addrCount == 2); - KJ_EXPECT(tlsAddrCount == 1); + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 18\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Bad Request"_kj).wait(waitScope); - // Let everything expire. - clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 2); - io.waitScope.poll(); - KJ_EXPECT(count == 0); - KJ_EXPECT(tlsCount == 0); - KJ_EXPECT(addrCount == 0); - KJ_EXPECT(tlsAddrCount == 0); + expectEnd(*pipe.ends[1]); - // We can still request those hosts again. - doRequest(false, port1).wait(io.waitScope); - KJ_EXPECT(count == 1); - KJ_EXPECT(tlsCount == 0); - KJ_EXPECT(addrCount == 1); - KJ_EXPECT(tlsAddrCount == 0); + listenTask.wait(waitScope); } -#endif -// ----------------------------------------------------------------------------- +KJ_TEST("CONNECT HTTP-tunneled-over-CONNECT") { + KJ_HTTP_TEST_SETUP_IO; -#if KJ_HTTP_TEST_USE_OS_PIPE -// This test only makes sense using the real network. -KJ_TEST("HttpClient to capnproto.org") { - auto io = kj::setupAsyncIo(); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto maybeConn = io.provider->getNetwork().parseAddress("capnproto.org", 80) - .then([](kj::Own addr) { - auto promise = addr->connect(); - return promise.attach(kj::mv(addr)); - }).then([](kj::Own&& connection) -> kj::Maybe> { - return kj::mv(connection); - }, [](kj::Exception&& e) -> kj::Maybe> { - KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); - return nullptr; - }).wait(io.waitScope); + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectHttpService service(table); + HttpServer server(timer, table, service); - KJ_IF_MAYBE(conn, maybeConn) { - // Successfully connected to capnproto.org. Try doing GET /. We expect to get a redirect to - // HTTPS, because what kind of horrible web site would serve in plaintext, really? + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); - HttpHeaderTable table; - auto client = newHttpClient(table, **conn); + auto client = newHttpClient(table, *pipe.ends[1]); - HttpHeaders headers(table); - headers.set(HttpHeaderId::HOST, "capnproto.org"); + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; - auto response = client->request(HttpMethod::GET, "/", headers).response.wait(io.waitScope); - KJ_EXPECT(response.statusCode / 100 == 3); - auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); - KJ_EXPECT(location == "https://capnproto.org/"); + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(connectHeaderTable), {}); - auto body = response.body->readAllText().wait(io.waitScope); - } + auto text = request.status.then([ + &tunneledHeaderTable, + &settings, + io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + auto client = newHttpClient(tunneledHeaderTable, *io, settings) + .attach(kj::mv(io)); + + return client->request(HttpMethod::GET, "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) { + return response.body->readAllText().attach(kj::mv(response)); + }).attach(kj::mv(client)); + }).wait(waitScope); + + KJ_ASSERT(text == "hello there"); } -#endif -// ======================================================================================= -// Misc bugfix tests +KJ_TEST("CONNECT HTTP-tunneled-over-pipelined-CONNECT") { + KJ_HTTP_TEST_SETUP_IO; -class ReadCancelHttpService final: public HttpService { - // HttpService that tries to read all request data but cancels after 1ms and sends a response. -public: - ReadCancelHttpService(kj::Timer& timer, HttpHeaderTable& headerTable) - : timer(timer), headerTable(headerTable) {} + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& responseSender) override { - if (method == HttpMethod::POST) { - // Try to read all content, but cancel after 1ms. - return requestBody.readAllBytes().ignoreResult() - .exclusiveJoin(timer.afterDelay(1 * kj::MILLISECONDS)) - .then([this, &responseSender]() { - responseSender.send(408, "Request Timeout", kj::HttpHeaders(headerTable), uint64_t(0)); - }); - } else { - responseSender.send(200, "OK", kj::HttpHeaders(headerTable), uint64_t(0)); - return kj::READY_NOW; - } - } + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectHttpService service(table); + HttpServer server(timer, table, service); -private: - kj::Timer& timer; - HttpHeaderTable& headerTable; -}; + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); -KJ_TEST("canceling a length stream mid-read correctly discards rest of request") { + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto request = client->connect( + "https://exmaple.org"_kj, HttpHeaders(connectHeaderTable), {}); + auto conn = kj::mv(request.connection); + auto proxyClient = newHttpClient(tunneledHeaderTable, *conn, settings).attach(kj::mv(conn)); + + auto get = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)); + auto text = get.response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }).attach(kj::mv(proxyClient)).wait(waitScope); + + KJ_ASSERT(text == "hello there"); +} + +KJ_TEST("CONNECT pipelined via an adapter") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + kj::TimerImpl timer(kj::origin()); HttpHeaderTable table; - ReadCancelHttpService service(timer, table); + ConnectHttpService service(table); HttpServer server(timer, table, service); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); - { - static constexpr kj::StringPtr REQUEST = - "POST / HTTP/1.1\r\n" - "Content-Length: 6\r\n" - "\r\n" - "fooba"_kj; // incomplete - pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + bool acceptCalled = false; - auto promise = expectRead(*pipe.ends[1], - "HTTP/1.1 408 Request Timeout\r\n" - "Content-Length: 0\r\n" - "\r\n"_kj); + auto client = newHttpClient(table, *pipe.ends[1]); + auto adaptedService = kj::newHttpService(*client).attach(kj::mv(client)); - KJ_EXPECT(!promise.poll(waitScope)); + // adaptedService is an HttpService that wraps an HttpClient that sends + // a request to server. - // Trigger timout, then response should be sent. - timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); - KJ_ASSERT(promise.poll(waitScope)); - promise.wait(waitScope); - } + auto clientPipe = newTwoWayPipe(); - // We left our request stream hanging. The server will try to read and discard the request body. - // Let's give it the rest of the data, followed by a second request. - { - static constexpr kj::StringPtr REQUEST = - "r" - "GET / HTTP/1.1\r\n" - "\r\n"_kj; - pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + struct ResponseImpl final: public HttpService::ConnectResponse { + bool& acceptCalled; + ResponseImpl(bool& acceptCalled) : acceptCalled(acceptCalled) {} + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + acceptCalled = true; + } - auto promise = expectRead(*pipe.ends[1], - "HTTP/1.1 200 OK\r\n" - "Content-Length: 0\r\n" - "\r\n"_kj); - KJ_ASSERT(promise.poll(waitScope)); - promise.wait(waitScope); - } + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_UNREACHABLE; + } + }; + + ResponseImpl response(acceptCalled); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto promise = adaptedService->connect("https://example.org"_kj, + HttpHeaders(connectHeaderTable), + *clientPipe.ends[0], + response, + {}).attach(kj::mv(clientPipe.ends[0])); + + auto proxyClient = newHttpClient(tunneledHeaderTable, *clientPipe.ends[1], settings) + .attach(kj::mv(clientPipe.ends[1])); + + auto text = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }).wait(waitScope); + + KJ_ASSERT(acceptCalled); + KJ_ASSERT(text == "hello there"); } -KJ_TEST("canceling a chunked stream mid-read correctly discards rest of request") { +KJ_TEST("CONNECT pipelined via an adapter (reject)") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + kj::TimerImpl timer(kj::origin()); HttpHeaderTable table; - ReadCancelHttpService service(timer, table); + ConnectRejectService service(table); HttpServer server(timer, table, service); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); - { - static constexpr kj::StringPtr REQUEST = - "POST / HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "6\r\n" - "fooba"_kj; // incomplete chunk - pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + bool rejectCalled = false; + bool failedAsExpected = false; - auto promise = expectRead(*pipe.ends[1], - "HTTP/1.1 408 Request Timeout\r\n" - "Content-Length: 0\r\n" - "\r\n"_kj); + auto client = newHttpClient(table, *pipe.ends[1]); + auto adaptedService = kj::newHttpService(*client).attach(kj::mv(client)); - KJ_EXPECT(!promise.poll(waitScope)); + // adaptedService is an HttpService that wraps an HttpClient that sends + // a request to server. - // Trigger timout, then response should be sent. - timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); - KJ_ASSERT(promise.poll(waitScope)); - promise.wait(waitScope); - } + auto clientPipe = newTwoWayPipe(); - // We left our request stream hanging. The server will try to read and discard the request body. - // Let's give it the rest of the data, followed by a second request. - { - static constexpr kj::StringPtr REQUEST = - "r\r\n" - "4a\r\n" - "this is some text that is the body of a chunk and not a valid chunk header\r\n" - "0\r\n" - "\r\n" - "GET / HTTP/1.1\r\n" - "\r\n"_kj; - pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + struct ResponseImpl final: public HttpService::ConnectResponse { + bool& rejectCalled; + kj::OneWayPipe pipe; + ResponseImpl(bool& rejectCalled) + : rejectCalled(rejectCalled), + pipe(kj::newOneWayPipe()) {} + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + KJ_UNREACHABLE; + } - auto promise = expectRead(*pipe.ends[1], - "HTTP/1.1 200 OK\r\n" - "Content-Length: 0\r\n" - "\r\n"_kj); - KJ_ASSERT(promise.poll(waitScope)); - promise.wait(waitScope); - } + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + rejectCalled = true; + return kj::mv(pipe.out); + } + + kj::Own getRejectStream() { + return kj::mv(pipe.in); + } + }; + + ResponseImpl response(rejectCalled); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto promise = adaptedService->connect("https://example.org"_kj, + HttpHeaders(connectHeaderTable), + *clientPipe.ends[0], + response, + {}).attach(kj::mv(clientPipe.ends[0])); + + auto proxyClient = newHttpClient(tunneledHeaderTable, *clientPipe.ends[1], settings) + .attach(kj::mv(clientPipe.ends[1])); + + auto text = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }, [&](kj::Exception&& ex) -> kj::Promise { + // We fully expect the stream to fail here. + if (ex.getDescription() == "stream disconnected prematurely") { + failedAsExpected = true; + } + return kj::str("ok"); + }).wait(waitScope); + + auto rejectStream = response.getRejectStream(); + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE + expectRead(*rejectStream, "boom"_kj).wait(waitScope); +#endif + + KJ_ASSERT(rejectCalled); + KJ_ASSERT(failedAsExpected); + KJ_ASSERT(text == "ok"); } } // namespace diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.c++ index f8bacf4e92a..aae47ad18b2 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.c++ @@ -20,15 +20,20 @@ // THE SOFTWARE. #include "http.h" +#include "kj/exception.h" #include "url.h" #include #include +#include #include #include #include #include #include #include +#if KJ_HAS_ZLIB +#include +#endif // KJ_HAS_ZLIB namespace kj { @@ -335,13 +340,17 @@ kj::StringPtr KJ_STRINGIFY(HttpMethod method) { return METHOD_NAMES[static_cast(method)]; } -static kj::Maybe consumeHttpMethod(char*& ptr) { +kj::StringPtr KJ_STRINGIFY(HttpConnectMethod method) { + return "CONNECT"_kj; +} + +static kj::Maybe> consumeHttpMethod(char*& ptr) { char* p = ptr; #define EXPECT_REST(prefix, suffix) \ if (strncmp(p, #suffix, sizeof(#suffix)-1) == 0) { \ ptr = p + (sizeof(#suffix)-1); \ - return HttpMethod::prefix##suffix; \ + return kj::Maybe>(HttpMethod::prefix##suffix); \ } else { \ return nullptr; \ } @@ -351,7 +360,18 @@ static kj::Maybe consumeHttpMethod(char*& ptr) { case 'C': switch (*p++) { case 'H': EXPECT_REST(CH,ECKOUT) - case 'O': EXPECT_REST(CO,PY) + case 'O': + switch (*p++) { + case 'P': EXPECT_REST(COP,Y) + case 'N': + if (strncmp(p, "NECT", 4) == 0) { + ptr = p + 4; + return kj::Maybe>(HttpConnectMethod()); + } else { + return nullptr; + } + default: return nullptr; + } default: return nullptr; } case 'D': EXPECT_REST(D,ELETE) @@ -413,6 +433,19 @@ static kj::Maybe consumeHttpMethod(char*& ptr) { } kj::Maybe tryParseHttpMethod(kj::StringPtr name) { + KJ_IF_MAYBE(method, tryParseHttpMethodAllowingConnect(name)) { + KJ_SWITCH_ONEOF(*method) { + KJ_CASE_ONEOF(m, HttpMethod) { return m; } + KJ_CASE_ONEOF(m, HttpConnectMethod) { return nullptr; } + } + KJ_UNREACHABLE; + } else { + return nullptr; + } +} + +kj::Maybe> tryParseHttpMethodAllowingConnect( + kj::StringPtr name) { // const_cast OK because we don't actually access it. consumeHttpMethod() is also called by some // code later than explicitly needs to use a non-const pointer. char* ptr = const_cast(name.begin()); @@ -528,7 +561,9 @@ struct HttpHeaderTable::IdsByNameMap { }; HttpHeaderTable::Builder::Builder() - : table(kj::heap()) {} + : table(kj::heap()) { + table->buildStatus = BuildStatus::BUILDING; +} HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) { requireValidHeaderName(name); @@ -576,7 +611,11 @@ bool HttpHeaders::isValidHeaderValue(kj::StringPtr value) { HttpHeaders::HttpHeaders(const HttpHeaderTable& table) : table(&table), - indexedHeaders(kj::heapArray(table.idCount())) {} + indexedHeaders(kj::heapArray(table.idCount())) { + KJ_ASSERT( + table.isReady(), "HttpHeaders object was constructed from " + "HttpHeaderTable that wasn't fully built yet at the time of construction"); +} void HttpHeaders::clear() { for (auto& header: indexedHeaders) { @@ -876,6 +915,23 @@ static char* trimHeaderEnding(kj::ArrayPtr content) { } HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr content) { + KJ_SWITCH_ONEOF(tryParseRequestOrConnect(content)) { + KJ_CASE_ONEOF(request, Request) { + return kj::mv(request); + } + KJ_CASE_ONEOF(error, ProtocolError) { + return kj::mv(error); + } + KJ_CASE_ONEOF(connect, ConnectRequest) { + return ProtocolError { 501, "Not Implemented", + "Unrecognized request method.", content }; + } + } + KJ_UNREACHABLE; +} + +HttpHeaders::RequestConnectOrProtocolError HttpHeaders::tryParseRequestOrConnect( + kj::ArrayPtr content) { char* end = trimHeaderEnding(content); if (end == nullptr) { return ProtocolError { 400, "Bad Request", @@ -884,27 +940,36 @@ HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr path; + KJ_IF_MAYBE(p, consumeWord(ptr)) { + path = *p; + } else { + return ProtocolError { 400, "Bad Request", + "Invalid request line.", content }; + } + + KJ_SWITCH_ONEOF(*method) { + KJ_CASE_ONEOF(m, HttpMethod) { + result = HttpHeaders::Request { m, KJ_ASSERT_NONNULL(path) }; + } + KJ_CASE_ONEOF(m, HttpConnectMethod) { + result = HttpHeaders::ConnectRequest { KJ_ASSERT_NONNULL(path) }; + } + } } else { return ProtocolError { 501, "Not Implemented", "Unrecognized request method.", content }; } - KJ_IF_MAYBE(path, consumeWord(ptr)) { - request.url = *path; - } else { - return ProtocolError { 400, "Bad Request", - "Invalid request line.", content }; - } - // Ignore rest of line. Don't care about "HTTP/1.1" or whatever. consumeLine(ptr); @@ -913,7 +978,7 @@ HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr content) { @@ -983,6 +1048,12 @@ kj::String HttpHeaders::serializeRequest( return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders); } +kj::String HttpHeaders::serializeConnectRequest( + kj::StringPtr authority, + kj::ArrayPtr connectionHeaders) const { + return serialize("CONNECT"_kj, authority, kj::StringPtr("HTTP/1.1"), connectionHeaders); +} + kj::String HttpHeaders::serializeResponse( uint statusCode, kj::StringPtr statusText, kj::ArrayPtr connectionHeaders) const { @@ -1043,26 +1114,149 @@ kj::String HttpHeaders::toString() const { namespace { +template +class WrappableStreamMixin { + // Both HttpInputStreamImpl and HttpOutputStream are commonly wrapped by a class that implements + // a particular type of body stream, such as a chunked body or a fixed-length body. That wrapper + // stream is passed back to the application to represent the specific request/response body, but + // the inner stream is associated with the connection and can be reused several times. + // + // It's easy for applications to screw up and hold on to a body stream beyond the lifetime of the + // underlying connection stream. This used to lead to UAF. This mixin class implements behavior + // that detached the wrapper if it outlives the wrapped stream, so that we log errors and + +public: + WrappableStreamMixin() = default; + WrappableStreamMixin(WrappableStreamMixin&& other) { + // This constructor is only needed by HttpServer::Connection::makeHttpInput() which constructs + // a new stream and returns it. Technically the constructor will always be elided anyway. + KJ_REQUIRE(other.currentWrapper == nullptr, "can't move a wrappable object that has wrappers!"); + } + KJ_DISALLOW_COPY(WrappableStreamMixin); + + ~WrappableStreamMixin() noexcept(false) { + KJ_IF_MAYBE(w, currentWrapper) { + KJ_LOG(ERROR, "HTTP connection destroyed while HTTP body streams still exist", + kj::getStackTrace()); + *w = nullptr; + } + } + + void setCurrentWrapper(kj::Maybe& weakRef) { + // Tracks the current `HttpEntityBodyReader` instance which is wrapping this stream. There can + // be only one wrapper at a time, and the wrapper must be destroyed before the underlying HTTP + // connection is torn down. The purpose of tracking the wrapper here is to detect when these + // rules are violated by apps, and log an error instead of going UB. + // + // `weakRef` is the wrapper's pointer to this object. If the underlying stream is destroyed + // before the wrapper, then `weakRef` will be nulled out. + + // The API should prevent an app from obtaining multiple wrappers with the same backing stream. + KJ_ASSERT(currentWrapper == nullptr, + "bug in KJ HTTP: only one HTTP stream wrapper can exist at a time"); + + currentWrapper = weakRef; + weakRef = static_cast(*this); + } + + void unsetCurrentWrapper(kj::Maybe& weakRef) { + auto& current = KJ_ASSERT_NONNULL(currentWrapper); + KJ_ASSERT(¤t == &weakRef, + "bug in KJ HTTP: unsetCurrentWrapper() passed the wrong wrapper"); + weakRef = nullptr; + currentWrapper = nullptr; + } + +private: + kj::Maybe&> currentWrapper; +}; + +// ======================================================================================= + static constexpr size_t MIN_BUFFER = 4096; static constexpr size_t MAX_BUFFER = 128 * 1024; static constexpr size_t MAX_CHUNK_HEADER_SIZE = 32; -class HttpInputStreamImpl final: public HttpInputStream { +class HttpInputStreamImpl final: public HttpInputStream, + public WrappableStreamMixin { +private: + static kj::OneOf getResumingRequest( + kj::OneOf method, + kj::StringPtr url) { + KJ_SWITCH_ONEOF(method) { + KJ_CASE_ONEOF(m, HttpMethod) { + return HttpHeaders::Request { m, url }; + } + KJ_CASE_ONEOF(m, HttpConnectMethod) { + return HttpHeaders::ConnectRequest { url }; + } + } + KJ_UNREACHABLE; + } public: explicit HttpInputStreamImpl(AsyncInputStream& inner, const HttpHeaderTable& table) : inner(inner), headerBuffer(kj::heapArray(MIN_BUFFER)), headers(table) { } + explicit HttpInputStreamImpl(AsyncInputStream& inner, + kj::Array headerBufferParam, + kj::ArrayPtr leftoverParam, + kj::OneOf method, + kj::StringPtr url, + HttpHeaders headers) + : inner(inner), + headerBuffer(kj::mv(headerBufferParam)), + // Initialize `messageHeaderEnd` to a safe value, we'll adjust it below. + messageHeaderEnd(leftoverParam.begin() - headerBuffer.begin()), + leftover(leftoverParam), + headers(kj::mv(headers)), + resumingRequest(getResumingRequest(method, url)) { + // Constructor used for resuming a SuspendedRequest. + + // We expect headerBuffer to look like this: + // [CR] LF + // We initialized `messageHeaderEnd` to the beginning of `leftover`, but we want to point it at + // the CR (or LF if there's no CR). + KJ_REQUIRE(messageHeaderEnd >= 2 && leftover.end() <= headerBuffer.end(), + "invalid SuspendedRequest - leftover buffer not where it should be"); + KJ_REQUIRE(leftover.begin()[-1] == '\n', "invalid SuspendedRequest - missing LF"); + messageHeaderEnd -= 1 + (leftover.begin()[-2] == '\r'); + + // We're in the middle of a message, so set up our state as such. Note that the only way to + // resume a SuspendedRequest is via an HttpServer, but HttpServers never call + // `awaitNextMessage()` before fully reading request bodies, meaning we expect that + // `messageReadQueue` will never be used. + ++pendingMessageCount; + auto paf = kj::newPromiseAndFulfiller(); + onMessageDone = kj::mv(paf.fulfiller); + messageReadQueue = kj::mv(paf.promise); + } + bool canReuse() { return !broken && pendingMessageCount == 0; } + bool canSuspend() { + // We are at a suspendable point if we've parsed the headers, but haven't consumed anything + // beyond that. + // + // TODO(cleanup): This is a silly check; we need a more defined way to track the state of the + // stream. + bool messageHeaderEndLooksRight = + (leftover.begin() - (headerBuffer.begin() + messageHeaderEnd) == 2 && + leftover.begin()[-1] == '\n' && leftover.begin()[-2] == '\r') + || (leftover.begin() - (headerBuffer.begin() + messageHeaderEnd) == 1 && + leftover.begin()[-1] == '\n'); + + return !broken && headerBuffer.size() > 0 && messageHeaderEndLooksRight; + } + // --------------------------------------------------------------------------- // public interface kj::Promise readRequest() override { return readRequestHeaders() - .then([this](HttpHeaders::RequestOrProtocolError&& requestOrProtocolError) + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) -> HttpInputStream::Request { auto request = KJ_REQUIRE_NONNULL( requestOrProtocolError.tryGet(), "bad request"); @@ -1072,6 +1266,27 @@ public: }); } + kj::Promise> readRequestAllowingConnect() override { + return readRequestHeaders() + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) + -> kj::OneOf { + KJ_SWITCH_ONEOF(requestOrProtocolError) { + KJ_CASE_ONEOF(request, HttpHeaders::Request) { + auto body = getEntityBody(HttpInputStreamImpl::REQUEST, request.method, 0, headers); + return HttpInputStream::Request { request.method, request.url, headers, kj::mv(body) }; + } + KJ_CASE_ONEOF(request, HttpHeaders::ConnectRequest) { + auto body = getEntityBody(HttpInputStreamImpl::REQUEST, HttpConnectMethod(), 0, headers); + return HttpInputStream::Connect { request.authority, headers, kj::mv(body) }; + } + KJ_CASE_ONEOF(error, HttpHeaders::ProtocolError) { + KJ_FAIL_REQUIRE("bad request"); + } + } + KJ_UNREACHABLE; + }); + } + kj::Promise readResponse(HttpMethod requestMethod) override { return readResponseHeaders() .then([this,requestMethod](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) @@ -1128,6 +1343,11 @@ public: // Used on the client to detect when idle connections are closed from the server end. (In this // case, the promise always returns false or is canceled.) + if (resumingRequest != nullptr) { + // We're resuming a request, so report that we have a message. + return true; + } + if (onMessageDone != nullptr) { // We're still working on reading the previous body. auto fork = messageReadQueue.fork(); @@ -1166,10 +1386,10 @@ public: auto paf = kj::newPromiseAndFulfiller(); auto promise = messageReadQueue - .then(kj::mvCapture(paf.fulfiller, [this](kj::Own> fulfiller) { + .then([this,fulfiller=kj::mv(paf.fulfiller)]() mutable { onMessageDone = kj::mv(fulfiller); return readHeader(HeaderType::MESSAGE, 0, 0); - })); + }); messageReadQueue = kj::mv(paf.promise); @@ -1202,10 +1422,15 @@ public: }); } - inline kj::Promise readRequestHeaders() { + inline kj::Promise readRequestHeaders() { + KJ_IF_MAYBE(resuming, resumingRequest) { + KJ_DEFER(resumingRequest = nullptr); + return HttpHeaders::RequestConnectOrProtocolError(*resuming); + } + return readMessageHeaders().then([this](kj::ArrayPtr text) { headers.clear(); - return headers.tryParseRequest(text); + return headers.tryParseRequestOrConnect(text); }); } @@ -1257,7 +1482,9 @@ public: }; kj::Own getEntityBody( - RequestOrResponse type, HttpMethod method, uint statusCode, + RequestOrResponse type, + kj::OneOf method, + uint statusCode, const kj::HttpHeaders& headers); struct ReleasedBuffer { @@ -1283,6 +1510,9 @@ private: HttpHeaders headers; // Parsed headers, after a call to parseAwaited*(). + kj::Maybe> resumingRequest; + // Non-null if we're resuming a SuspendedRequest. + bool lineBreakBeforeNextHeader = false; // If true, the next await should expect to start with a spurious '\n' or '\r\n'. This happens // as a side-effect of HTTP chunked encoding, where such a newline is added to the end of each @@ -1465,25 +1695,44 @@ private: class HttpEntityBodyReader: public kj::AsyncInputStream { public: - HttpEntityBodyReader(HttpInputStreamImpl& inner): inner(inner) {} + HttpEntityBodyReader(HttpInputStreamImpl& inner) { + inner.setCurrentWrapper(weakInner); + } ~HttpEntityBodyReader() noexcept(false) { if (!finished) { - inner.abortRead(); + KJ_IF_MAYBE(inner, weakInner) { + inner->unsetCurrentWrapper(weakInner); + inner->abortRead(); + } else { + // Since we're in a destructor, log an error instead of throwing. + KJ_LOG(ERROR, "HTTP body input stream outlived underlying connection", kj::getStackTrace()); + } } } protected: - HttpInputStreamImpl& inner; + HttpInputStreamImpl& getInner() { + KJ_IF_MAYBE(i, weakInner) { + return *i; + } else if (finished) { + // This is a bug in the implementations in this file, not the app. + KJ_FAIL_ASSERT("bug in KJ HTTP: tried to access inner stream after it had been released"); + } else { + KJ_FAIL_REQUIRE("HTTP body input stream outlived underlying connection"); + } + } void doneReading() { - KJ_REQUIRE(!finished); + auto& inner = getInner(); + inner.unsetCurrentWrapper(weakInner); finished = true; inner.finishRead(); } - inline bool alreadyDone() { return finished; } + inline bool alreadyDone() { return weakInner == nullptr; } private: + kj::Maybe weakInner; bool finished = false; }; @@ -1500,7 +1749,7 @@ public: } Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); + return constPromise(); } Maybe tryGetLength() override { @@ -1519,9 +1768,9 @@ public: : HttpEntityBodyReader(inner) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - if (alreadyDone()) return size_t(0); + if (alreadyDone()) return constPromise(); - return inner.tryRead(buffer, minBytes, maxBytes) + return getInner().tryRead(buffer, minBytes, maxBytes) .then([=](size_t amount) { if (amount < minBytes) { doneReading(); @@ -1545,19 +1794,25 @@ public: } Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(clean, "can't read more data after a previous read didn't complete"); + clean = false; return tryReadInternal(buffer, minBytes, maxBytes, 0); } private: size_t length; + bool clean = true; Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead) { - if (length == 0) return size_t(0); + if (length == 0) { + clean = true; + return constPromise(); + } // We have to set minBytes to 1 here so that if we read any data at all, we update our // counter immediately, so that we still know where we are in case of cancellation. - return inner.tryRead(buffer, 1, kj::min(maxBytes, length)) + return getInner().tryRead(buffer, 1, kj::min(maxBytes, length)) .then([=](size_t amount) -> kj::Promise { length -= amount; if (length > 0) { @@ -1574,6 +1829,7 @@ private: } else if (length == 0) { doneReading(); } + clean = true; return amount + alreadyRead; }); } @@ -1587,19 +1843,23 @@ public: : HttpEntityBodyReader(inner) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(clean, "can't read more data after a previous read didn't complete"); + clean = false; return tryReadInternal(buffer, minBytes, maxBytes, 0); } private: size_t chunkSize = 0; + bool clean = true; Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead) { if (alreadyDone()) { + clean = true; return alreadyRead; } else if (chunkSize == 0) { // Read next chunk header. - return inner.readChunkHeader().then([=](uint64_t nextChunkSize) { + return getInner().readChunkHeader().then([=](uint64_t nextChunkSize) { if (nextChunkSize == 0) { doneReading(); } @@ -1611,7 +1871,7 @@ private: // Read current chunk. // We have to set minBytes to 1 here so that if we read any data at all, we update our // counter immediately, so that we still know where we are in case of cancellation. - return inner.tryRead(buffer, 1, kj::min(maxBytes, chunkSize)) + return getInner().tryRead(buffer, 1, kj::min(maxBytes, chunkSize)) .then([=](size_t amount) -> kj::Promise { chunkSize -= amount; if (amount == 0) { @@ -1622,6 +1882,7 @@ private: return tryReadInternal(reinterpret_cast(buffer) + amount, minBytes - amount, maxBytes - amount, alreadyRead + amount); } + clean = true; return alreadyRead + amount; }); } @@ -1661,35 +1922,49 @@ static_assert(!fastCaseCmp<'f','O','o','B'>("FooB1"), ""); static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), ""); kj::Own HttpInputStreamImpl::getEntityBody( - RequestOrResponse type, HttpMethod method, uint statusCode, + RequestOrResponse type, + kj::OneOf method, + uint statusCode, const kj::HttpHeaders& headers) { + KJ_REQUIRE(headerBuffer.size() > 0, "Cannot get entity body after header buffer release."); + + auto isHeadRequest = method.tryGet().map([](auto& m) { + return m == HttpMethod::HEAD; + }).orDefault(false); + + auto isConnectRequest = method.is(); + // Rules to determine how HTTP entity-body is delimited: // https://tools.ietf.org/html/rfc7230#section-3.3.3 - // #1 if (type == RESPONSE) { - if (method == HttpMethod::HEAD) { + if (isHeadRequest) { // Body elided. kj::Maybe length; KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { length = strtoull(cl->cStr(), nullptr, 10); } else if (headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) { - // HACK: Neither Content-Length nor Transfer-Encoding header in response to HEAD request. - // Propagate this fact with a 0 expected body length. + // HACK: Neither Content-Length nor Transfer-Encoding header in response to HEAD + // request. Propagate this fact with a 0 expected body length. length = uint64_t(0); } return kj::heap(*this, length); + } else if (isConnectRequest && statusCode >= 200 && statusCode < 300) { + KJ_FAIL_ASSERT("a CONNECT response with a 2xx status does not have an entity body to get"); } else if (statusCode == 204 || statusCode == 304) { // No body. return kj::heap(*this, uint64_t(0)); } } - // #2 deals with the CONNECT method which is handled separately. + // For CONNECT requests messages, we let the rest of the logic play out. + // We already check before here to ensure that Transfer-Encoding and + // Content-Length headers are not present in which case the code below + // does the right thing. // #3 KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) { - // TODO(someday): Support plugable transfer encodings? Or at least gzip? + // TODO(someday): Support pluggable transfer encodings? Or at least gzip? // TODO(someday): Support stacked transfer encodings, e.g. "gzip, chunked". // NOTE: #3¶3 is ambiguous about what should happen if Transfer-Encoding and Content-Length are @@ -1705,9 +1980,9 @@ kj::Own HttpInputStreamImpl::getEntityBody( // #3¶2 KJ_REQUIRE(type != REQUEST, "request body cannot have Transfer-Encoding other than chunked"); return kj::heap(*this); - } else { - KJ_FAIL_REQUIRE("unknown transfer encoding", *te) { break; } } + + KJ_FAIL_REQUIRE("unknown transfer encoding", *te) { break; }; } // #4 and #5 @@ -1760,7 +2035,7 @@ kj::Own newHttpInputStream( namespace { -class HttpOutputStream { +class HttpOutputStream: public WrappableStreamMixin { public: HttpOutputStream(AsyncOutputStream& inner): inner(inner) {} @@ -1902,14 +2177,59 @@ private: // is empty, then they make the write directly, using `writeInProgress` to detect and block // concurrent writes. - writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) { + writeQueue = writeQueue.then([this,content=kj::mv(content)]() mutable { auto promise = inner.write(content.begin(), content.size()); return promise.attach(kj::mv(content)); - })); + }); + } +}; + +class HttpEntityBodyWriter: public kj::AsyncOutputStream { +public: + HttpEntityBodyWriter(HttpOutputStream& inner) { + inner.setCurrentWrapper(weakInner); + } + ~HttpEntityBodyWriter() noexcept(false) { + if (!finished) { + KJ_IF_MAYBE(inner, weakInner) { + inner->unsetCurrentWrapper(weakInner); + inner->abortBody(); + } else { + // Since we're in a destructor, log an error instead of throwing. + KJ_LOG(ERROR, "HTTP body output stream outlived underlying connection", + kj::getStackTrace()); + } + } + } + +protected: + HttpOutputStream& getInner() { + KJ_IF_MAYBE(i, weakInner) { + return *i; + } else if (finished) { + // This is a bug in the implementations in this file, not the app. + KJ_FAIL_ASSERT("bug in KJ HTTP: tried to access inner stream after it had been released"); + } else { + KJ_FAIL_REQUIRE("HTTP body output stream outlived underlying connection"); + } } + + void doneWriting() { + auto& inner = getInner(); + inner.unsetCurrentWrapper(weakInner); + finished = true; + inner.finishBody(); + } + + inline bool alreadyDone() { return weakInner == nullptr; } + +private: + kj::Maybe weakInner; + bool finished = false; }; class HttpNullEntityWriter final: public kj::AsyncOutputStream { + // Does not inherit HttpEntityBodyWriter because it doesn't actually write anything. public: Promise write(const void* buffer, size_t size) override { return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()"); @@ -1923,6 +2243,7 @@ public: }; class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream { + // Does not inherit HttpEntityBodyWriter because it doesn't actually write anything. public: Promise write(const void* buffer, size_t size) override { return kj::READY_NOW; @@ -1935,16 +2256,11 @@ public: } }; -class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream { +class HttpFixedLengthEntityWriter final: public HttpEntityBodyWriter { public: HttpFixedLengthEntityWriter(HttpOutputStream& inner, uint64_t length) - : inner(inner), length(length) { - if (length == 0) inner.finishBody(); - } - ~HttpFixedLengthEntityWriter() noexcept(false) { - if (length > 0 || inner.isWriteInProgress()) { - inner.abortBody(); - } + : HttpEntityBodyWriter(inner), length(length) { + if (length == 0) doneWriting(); } Promise write(const void* buffer, size_t size) override { @@ -1952,7 +2268,7 @@ public: KJ_REQUIRE(size <= length, "overwrote Content-Length"); length -= size; - return maybeFinishAfter(inner.writeBodyData(buffer, size)); + return maybeFinishAfter(getInner().writeBodyData(buffer, size)); } Promise write(ArrayPtr> pieces) override { uint64_t size = 0; @@ -1962,11 +2278,11 @@ public: KJ_REQUIRE(size <= length, "overwrote Content-Length"); length -= size; - return maybeFinishAfter(inner.writeBodyData(pieces)); + return maybeFinishAfter(getInner().writeBodyData(pieces)); } Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { - if (amount == 0) return Promise(uint64_t(0)); + if (amount == 0) return constPromise(); bool overshot = amount > length; if (overshot) { @@ -1987,10 +2303,10 @@ public: auto promise = amount == 0 ? kj::Promise(amount) - : inner.pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) { + : getInner().pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) { // Adjust for bytes not written. length += amount - actual; - if (length == 0) inner.finishBody(); + if (length == 0) doneWriting(); return actual; }); @@ -2016,32 +2332,32 @@ public: } Promise whenWriteDisconnected() override { - return inner.whenWriteDisconnected(); + return getInner().whenWriteDisconnected(); } private: - HttpOutputStream& inner; uint64_t length; kj::Promise maybeFinishAfter(kj::Promise promise) { if (length == 0) { - return promise.then([this]() { inner.finishBody(); }); + return promise.then([this]() { doneWriting(); }); } else { return kj::mv(promise); } } }; -class HttpChunkedEntityWriter final: public kj::AsyncOutputStream { +class HttpChunkedEntityWriter final: public HttpEntityBodyWriter { public: HttpChunkedEntityWriter(HttpOutputStream& inner) - : inner(inner) {} + : HttpEntityBodyWriter(inner) {} ~HttpChunkedEntityWriter() noexcept(false) { - if (inner.canWriteBodyData()) { - inner.writeBodyData(kj::str("0\r\n\r\n")); - inner.finishBody(); - } else { - inner.abortBody(); + if (!alreadyDone()) { + auto& inner = getInner(); + if (inner.canWriteBodyData()) { + inner.writeBodyData(kj::str("0\r\n\r\n")); + doneWriting(); + } } } @@ -2054,7 +2370,7 @@ public: parts[1] = kj::arrayPtr(reinterpret_cast(buffer), size); parts[2] = kj::StringPtr("\r\n").asBytes(); - auto promise = inner.writeBodyData(parts.asPtr()); + auto promise = getInner().writeBodyData(parts.asPtr()); return promise.attach(kj::mv(header), kj::mv(parts)); } @@ -2073,7 +2389,7 @@ public: partsBuilder.add(kj::StringPtr("\r\n").asBytes()); auto parts = partsBuilder.finish(); - auto promise = inner.writeBodyData(parts.asPtr()); + auto promise = getInner().writeBodyData(parts.asPtr()); return promise.attach(kj::mv(header), kj::mv(parts)); } @@ -2082,9 +2398,11 @@ public: // Hey, we know exactly how large the input is, so we can write just one chunk. uint64_t length = kj::min(amount, *l); + auto& inner = getInner(); inner.writeBodyData(kj::str(kj::hex(length), "\r\n")); return inner.pumpBodyFrom(input, length) .then([this,length](uint64_t actual) { + auto& inner = getInner(); if (actual < length) { inner.abortBody(); KJ_FAIL_REQUIRE( @@ -2103,25 +2421,36 @@ public: } Promise whenWriteDisconnected() override { - return inner.whenWriteDisconnected(); + return getInner().whenWriteDisconnected(); } - -private: - HttpOutputStream& inner; }; // ======================================================================================= -class WebSocketImpl final: public WebSocket { +class WebSocketImpl final: public WebSocket, private WebSocketErrorHandler { public: WebSocketImpl(kj::Own stream, kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfigParam = nullptr, + kj::Maybe errorHandler = nullptr, kj::Array buffer = kj::heapArray(4096), kj::ArrayPtr leftover = nullptr, kj::Maybe> waitBeforeSend = nullptr) : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator), + compressionConfig(kj::mv(compressionConfigParam)), + errorHandler(errorHandler.orDefault(*this)), sendingPong(kj::mv(waitBeforeSend)), - recvBuffer(kj::mv(buffer)), recvData(leftover) {} + recvBuffer(kj::mv(buffer)), recvData(leftover) { +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + compressionContext.emplace(ZlibContext::Mode::COMPRESS, *config); + decompressionContext.emplace(ZlibContext::Mode::DECOMPRESS, *config); + } +#else + KJ_REQUIRE(compressionConfig == nullptr, + "WebSocket compression is only supported if KJ is compiled with Zlib."); +#endif // KJ_HAS_ZLIB + } kj::Promise send(kj::ArrayPtr message) override { return sendImpl(OPCODE_BINARY, message); @@ -2211,36 +2540,63 @@ public: } auto& recvHeader = *reinterpret_cast(recvData.begin()); + if (recvHeader.hasRsv2or3()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Received frame had RSV bits 2 or 3 set", + }); + } recvData = recvData.slice(headerSize, recvData.size()); size_t payloadLen = recvHeader.getPayloadLen(); - - KJ_REQUIRE(payloadLen < maxSize, "WebSocket message is too large"); + if (payloadLen > maxSize) { + return errorHandler.handleWebSocketProtocolError({ + 1009, kj::str("Message is too large: ", payloadLen, " > ", maxSize) + }); + } auto opcode = recvHeader.getOpcode(); bool isData = opcode < OPCODE_FIRST_CONTROL; if (opcode == OPCODE_CONTINUATION) { - KJ_REQUIRE(!fragments.empty(), "unexpected continuation frame in WebSocket"); + if (fragments.empty()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Unexpected continuation frame" + }); + } opcode = fragmentOpcode; } else if (isData) { - KJ_REQUIRE(fragments.empty(), "expected continuation frame in WebSocket"); + if (!fragments.empty()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Missing continuation frame" + }); + } } bool isFin = recvHeader.isFin(); + bool isCompressed = false; kj::Array message; // space to allocate byte* payloadTarget; // location into which to read payload (size is payloadLen) + kj::Maybe originalMaxSize; // maxSize from first `receive()` call if (isFin) { - // Add space for NUL terminator when allocating text message. - size_t amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin); + size_t amountToAllocate; + if (recvHeader.isCompressed() || fragmentCompressed) { + // Add 4 since we append 0x00 0x00 0xFF 0xFF to the tail of the payload. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + amountToAllocate = payloadLen + 4; + isCompressed = true; + } else { + // Add space for NUL terminator when allocating text message. + amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin); + } if (isData && !fragments.empty()) { // Final frame of a fragmented message. Gather the fragments. size_t offset = 0; for (auto& fragment: fragments) offset += fragment.size(); message = kj::heapArray(offset + amountToAllocate); + originalMaxSize = offset + maxSize; // gives us back the original maximum message size. offset = 0; for (auto& fragment: fragments) { @@ -2251,28 +2607,36 @@ public: fragments.clear(); fragmentOpcode = 0; + fragmentCompressed = false; } else { // Single-frame message. message = kj::heapArray(amountToAllocate); + originalMaxSize = maxSize; // gives us back the original maximum message size. payloadTarget = message.begin(); } } else { // Fragmented message, and this isn't the final fragment. - KJ_REQUIRE(isData, "WebSocket control frame cannot be fragmented"); + if (!isData) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Received fragmented control frame" + }); + } message = kj::heapArray(payloadLen); payloadTarget = message.begin(); if (fragments.empty()) { // This is the first fragment, so set the opcode. fragmentOpcode = opcode; + fragmentCompressed = recvHeader.isCompressed(); } } Mask mask = recvHeader.getMask(); - auto handleMessage = kj::mvCapture(message, - [this,opcode,payloadTarget,payloadLen,mask,isFin,maxSize] - (kj::Array&& message) -> kj::Promise { + auto handleMessage = + [this,opcode,payloadTarget,payloadLen,mask,isFin,maxSize,originalMaxSize, + isCompressed,message=kj::mv(message)]() mutable + -> kj::Promise { if (!mask.isZero()) { mask.apply(kj::arrayPtr(payloadTarget, payloadLen)); } @@ -2284,14 +2648,66 @@ public: return receive(newMax); } + // Provide a reasonable error if a compressed frame is received without compression enabled. + if (isCompressed && compressionConfig == nullptr) { + return errorHandler.handleWebSocketProtocolError({ + 1002, kj::str( + "Received a WebSocket frame whose compression bit was set, but the compression " + "extension was not negotiated for this connection.") + }); + } + switch (opcode) { case OPCODE_CONTINUATION: // Shouldn't get here; handled above. KJ_UNREACHABLE; case OPCODE_TEXT: +#if KJ_HAS_ZLIB + if (isCompressed) { + auto& config = KJ_ASSERT_NONNULL(compressionConfig); + auto& decompressor = KJ_ASSERT_NONNULL(decompressionContext); + KJ_ASSERT(message.size() >= 4); + auto tail = message.slice(message.size() - 4, message.size()); + // Note that we added an additional 4 bytes to `message`s capacity to account for these + // extra bytes. See `amountToAllocate` in the if(recvHeader.isCompressed()) block above. + const byte tailBytes[] = {0x00, 0x00, 0xFF, 0xFF}; + memcpy(tail.begin(), tailBytes, sizeof(tailBytes)); + // We have to append 0x00 0x00 0xFF 0xFF to the message before inflating. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + if (config.inboundNoContextTakeover) { + // We must reset context on each message. + decompressor.reset(); + } + bool addNullTerminator = true; + // We want to add the null terminator when receiving a TEXT message. + auto decompressed = decompressor.processMessage(message, originalMaxSize, + addNullTerminator); + return Message(kj::String(decompressed.releaseAsChars())); + } +#endif // KJ_HAS_ZLIB message.back() = '\0'; return Message(kj::String(message.releaseAsChars())); case OPCODE_BINARY: +#if KJ_HAS_ZLIB + if (isCompressed) { + auto& config = KJ_ASSERT_NONNULL(compressionConfig); + auto& decompressor = KJ_ASSERT_NONNULL(decompressionContext); + KJ_ASSERT(message.size() >= 4); + auto tail = message.slice(message.size() - 4, message.size()); + // Note that we added an additional 4 bytes to `message`s capacity to account for these + // extra bytes. See `amountToAllocate` in the if(recvHeader.isCompressed()) block above. + const byte tailBytes[] = {0x00, 0x00, 0xFF, 0xFF}; + memcpy(tail.begin(), tailBytes, sizeof(tailBytes)); + // We have to append 0x00 0x00 0xFF 0xFF to the message before inflating. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + if (config.inboundNoContextTakeover) { + // We must reset context on each message. + decompressor.reset(); + } + auto decompressed = decompressor.processMessage(message, originalMaxSize); + return Message(decompressed.releaseAsBytes()); + } +#endif // KJ_HAS_ZLIB return Message(message.releaseAsBytes()); case OPCODE_CLOSE: if (message.size() < 2) { @@ -2311,9 +2727,11 @@ public: // Unsolicited pong. Ignore. return receive(maxSize); default: - KJ_FAIL_REQUIRE("unknown WebSocket opcode", opcode); + return errorHandler.handleWebSocketProtocolError({ + 1002, kj::str("Unknown opcode ", opcode) + }); } - }); + }; if (payloadLen <= recvData.size()) { // All data already received. @@ -2348,6 +2766,27 @@ public: return nullptr; } + KJ_IF_MAYBE(config, compressionConfig) { + KJ_IF_MAYBE(otherConfig, optOther->compressionConfig) { + if (config->outboundMaxWindowBits != otherConfig->inboundMaxWindowBits || + config->inboundMaxWindowBits != otherConfig->outboundMaxWindowBits || + config->inboundNoContextTakeover!= otherConfig->outboundNoContextTakeover || + config->outboundNoContextTakeover!= otherConfig->inboundNoContextTakeover) { + // Compression configurations differ. + return nullptr; + } + } else { + // Only one websocket uses compression. + return nullptr; + } + } else { + if (optOther->compressionConfig != nullptr) { + // Only one websocket uses compression. + return nullptr; + } + } + // Both websockets use compatible compression configurations so we can pump directly. + // Check same error conditions as with sendImpl(). KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); KJ_REQUIRE(!currentlySending, "another message send is already in progress"); @@ -2371,6 +2810,54 @@ public: uint64_t receivedByteCount() override { return receivedBytes; } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + if (maskKeyGenerator == nullptr) { + // `this` is the server side of a websocket. + if (ctx == ExtensionsContext::REQUEST) { + // The other WebSocket is (going to be) the client side of a WebSocket, i.e. this is a + // proxying pass-through scenario. Optimization is possible. Confusingly, we have to use + // generateExtensionResponse() (even though we're generating headers to be passed in a + // request) because this is the function that correctly maps our config's inbound/outbound + // to client/server. + KJ_IF_MAYBE(c, compressionConfig) { + return _::generateExtensionResponse(*c); + } else { + return kj::String(nullptr); // recommend no compression + } + } else { + // We're apparently arranging to pump from the server side of one WebSocket to the server + // side of another; i.e., we are a server, we have two clients, and we're trying to pump + // between them. We cannot optimize this case, because the masking requirements are + // different for client->server vs. server->client messages. Since we have to parse out + // the messages anyway there's no point in trying to match extensions, so return null. + return nullptr; + } + } else { + // `this` is the client side of a websocket. + if (ctx == ExtensionsContext::RESPONSE) { + // The other WebSocket is (going to be) the server side of a WebSocket, i.e. this is a + // proxying pass-through scenario. Optimization is possible. Confusingly, we have to use + // generateExtensionRequest() (even though we're generating headers to be passed in a + // response) because this is the function that correctly maps our config's inbound/outbound + // to server/client. + KJ_IF_MAYBE(c, compressionConfig) { + CompressionParameters arr[1]{*c}; + return _::generateExtensionRequest(arr); + } else { + return kj::String(nullptr); // recommend no compression + } + } else { + // We're apparently arranging to pump from the client side of one WebSocket to the client + // side of another; i.e., we are a client, we are connected to two servers, and we're + // trying to pump between them. We cannot optimize this case, because the masking + // requirements are different for client->server vs. server->client messages. Since we have + // to parse out the messages anyway there's no point in trying to match extensions, so + // return null. + return nullptr; + } + } + } + private: class Mask { public: @@ -2409,8 +2896,10 @@ private: class Header { public: - kj::ArrayPtr compose(bool fin, byte opcode, uint64_t payloadLen, Mask mask) { - bytes[0] = (fin ? FIN_MASK : 0) | opcode; + kj::ArrayPtr compose(bool fin, bool compressed, byte opcode, uint64_t payloadLen, + Mask mask) { + bytes[0] = (fin ? FIN_MASK : 0) | (compressed ? RSV1_MASK : 0) | opcode; + // Note that we can only set the compressed bit on DATA frames. bool hasMask = !mask.isZero(); size_t fill; @@ -2458,8 +2947,12 @@ private: return bytes[0] & FIN_MASK; } - bool hasRsv() const { - return bytes[0] & RSV_MASK; + bool isCompressed() const { + return bytes[0] & RSV1_MASK; + } + + bool hasRsv2or3() const { + return bytes[0] & RSV2_3_MASK; } byte getOpcode() const { @@ -2523,13 +3016,211 @@ private: byte bytes[14]; static constexpr byte FIN_MASK = 0x80; - static constexpr byte RSV_MASK = 0x70; + static constexpr byte RSV2_3_MASK = 0x30; + static constexpr byte RSV1_MASK = 0x40; static constexpr byte OPCODE_MASK = 0x0f; static constexpr byte USE_MASK_MASK = 0x80; static constexpr byte PAYLOAD_LEN_MASK = 0x7f; }; +#if KJ_HAS_ZLIB + class ZlibContext { + // `ZlibContext` is the WebSocket's interface to Zlib's compression/decompression functions. + // Depending on the `mode`, `ZlibContext` will act as a compressor or a decompressor. + public: + enum class Mode { + COMPRESS, + DECOMPRESS, + }; + + struct Result { + int processResult = 0; + kj::Array buffer; + size_t size = 0; // Number of bytes used; size <= buffer.size(). + }; + + ZlibContext(Mode mode, const CompressionParameters& config) : mode(mode) { + switch (mode) { + case Mode::COMPRESS: { + int windowBits = -config.outboundMaxWindowBits.orDefault(15); + // We use negative values because we want to use raw deflate. + if(windowBits == -8) { + // Zlib cannot accept `windowBits` of 8 for the deflater. However, due to an + // implementation quirk, `windowBits` of 8 and 9 would both use 250 bytes. + // Therefore, a decompressor using `windowBits` of 8 could safely inflate a message + // that a zlib client compressed using `windowBits` = 9. + // https://bugs.chromium.org/p/chromium/issues/detail?id=691074 + windowBits = -9; + } + int result = deflateInit2( + &ctx, + Z_DEFAULT_COMPRESSION, + Z_DEFLATED, + windowBits, + 8, // memLevel = 8 is the default + Z_DEFAULT_STRATEGY); + KJ_REQUIRE(result == Z_OK, "Failed to initialize compression context (deflate)."); + break; + } + case Mode::DECOMPRESS: { + int windowBits = -config.inboundMaxWindowBits.orDefault(15); + // We use negative values because we want to use raw inflate. + int result = inflateInit2(&ctx, windowBits); + KJ_REQUIRE(result == Z_OK, "Failed to initialize decompression context (inflate)."); + break; + } + } + } + + ~ZlibContext() noexcept(false) { + switch (mode) { + case Mode::COMPRESS: + deflateEnd(&ctx); + break; + case Mode::DECOMPRESS: + inflateEnd(&ctx); + break; + } + } + + KJ_DISALLOW_COPY_AND_MOVE(ZlibContext); + + kj::Array processMessage(kj::ArrayPtr message, + kj::Maybe maxSize = nullptr, + bool addNullTerminator = false) { + // If `this` is the compressor, calling `processMessage()` will compress the `message`. + // Likewise, if `this` is the decompressor, `processMessage()` will decompress the `message`. + // + // `maxSize` is only passed in when decompressing, since we want to ensure the decompressed + // message is smaller than the `maxSize` passed to `receive()`. + // + // If (de)compression is successful, the result is returned as a Vector, otherwise, + // an Exception is thrown. + + ctx.next_in = const_cast(reinterpret_cast(message.begin())); + ctx.avail_in = message.size(); + + kj::Vector parts(processLoop(maxSize)); + + size_t amountToAllocate = 0; + for (const auto& part : parts) { + amountToAllocate += part.size; + } + + if (addNullTerminator) { + // Add space for the null-terminator. + amountToAllocate += 1; + } + + kj::Array processedMessage = kj::heapArray(amountToAllocate); + size_t currentIndex = 0; // Current index into processedMessage. + for (const auto& part : parts) { + memcpy(&processedMessage[currentIndex], part.buffer.begin(), part.size); + // We need to use `part.size` to determine the number of useful bytes, since data after + // `part.size` is unused (and probably junk). + currentIndex += part.size; + } + + if (addNullTerminator) { + processedMessage[currentIndex++] = '\0'; + } + + KJ_ASSERT(currentIndex == processedMessage.size()); + + return kj::mv(processedMessage); + } + + void reset() { + // Resets the (de)compression context. This should only be called when the (de)compressor uses + // client/server_no_context_takeover. + switch (mode) { + case Mode::COMPRESS: { + KJ_ASSERT(deflateReset(&ctx) == Z_OK, "deflateReset() failed."); + break; + } + case Mode::DECOMPRESS: { + KJ_ASSERT(inflateReset(&ctx) == Z_OK, "inflateReset failed."); + break; + } + } + + } + + private: + Result pumpOnce() { + // Prepares Zlib's internal state for a call to deflate/inflate, then calls the relevant + // function to process the input buffer. It is assumed that the caller has already set up + // Zlib's input buffer. + // + // Since calls to deflate/inflate will process data until the input is empty, or until the + // output is full, multiple calls to `pumpOnce()` may be required to process the entire + // message. We're done processing once either `result` is `Z_STREAM_END`, or we get + // `Z_BUF_ERROR` and did not write any more output. + size_t bufSize = 4096; + Array buffer = kj::heapArray(bufSize); + ctx.next_out = buffer.begin(); + ctx.avail_out = bufSize; + + int result = Z_OK; + + switch (mode) { + case Mode::COMPRESS: + result = deflate(&ctx, Z_SYNC_FLUSH); + KJ_REQUIRE(result == Z_OK || result == Z_BUF_ERROR || result == Z_STREAM_END, + "Compression failed", result); + break; + case Mode::DECOMPRESS: + result = inflate(&ctx, Z_SYNC_FLUSH); + KJ_REQUIRE(result == Z_OK || result == Z_BUF_ERROR || result == Z_STREAM_END, + "Decompression failed", result, " with reason", ctx.msg); + break; + } + + return Result { + result, + kj::mv(buffer), + bufSize - ctx.avail_out, + }; + } + + kj::Vector processLoop(kj::Maybe maxSize) { + // Since Zlib buffers the writes, we want to continue processing until there's nothing left. + kj::Vector output; + size_t totalBytesProcessed = 0; + for (;;) { + Result result = pumpOnce(); + + auto status = result.processResult; + auto bytesProcessed = result.size; + if (bytesProcessed > 0) { + output.add(kj::mv(result)); + totalBytesProcessed += bytesProcessed; + KJ_IF_MAYBE(m, maxSize) { + // This is only non-null for `receive` calls, so we must be decompressing. We don't want + // the decompressed message to OOM us, so let's make sure it's not too big. + KJ_REQUIRE(totalBytesProcessed < *m, + "Decompressed WebSocket message is too large"); + } + } + + if ((ctx.avail_in == 0 && ctx.avail_out != 0) || status == Z_STREAM_END) { + // If we're out of input to consume, and we have space in the output buffer, then we must + // have flushed the remaining message, so we're done pumping. Alternatively, if we found a + // BFINAL deflate block, then we know the stream is completely finished. + if (status == Z_STREAM_END) { + reset(); + } + return kj::mv(output); + } + } + } + + Mode mode; + z_stream ctx = {}; + }; +#endif // KJ_HAS_ZLIB + static constexpr byte OPCODE_CONTINUATION = 0; static constexpr byte OPCODE_TEXT = 1; static constexpr byte OPCODE_BINARY = 2; @@ -2543,6 +3234,12 @@ private: kj::Own stream; kj::Maybe maskKeyGenerator; + kj::Maybe compressionConfig; + WebSocketErrorHandler& errorHandler; +#if KJ_HAS_ZLIB + kj::Maybe compressionContext; + kj::Maybe decompressionContext; +#endif // KJ_HAS_ZLIB bool hasSentClose = false; bool disconnected = false; @@ -2564,6 +3261,9 @@ private: // Perhaps it should be renamed to `blockSend` or `writeQueue`. uint fragmentOpcode = 0; + bool fragmentCompressed = false; + // For fragmented messages, was the first frame compressed? + // Note that subsequent frames of a compressed message will not set the RSV1 bit. kj::Vector> fragments; // If `fragments` is non-empty, we've already received some fragments of a message. // `fragmentOpcode` is the original opcode. @@ -2597,18 +3297,48 @@ private: Mask mask(maskKeyGenerator); - kj::Array ownMessage; - if (!mask.isZero()) { - // Sadness, we have to make a copy to apply the mask. - ownMessage = kj::heapArray(message); + bool useCompression = false; + kj::Maybe> compressedMessage; + if (opcode == OPCODE_BINARY || opcode == OPCODE_TEXT) { + // We can only compress data frames. +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + useCompression = true; + // Compress `message` according to `compressionConfig`s outbound parameters. + auto& compressor = KJ_ASSERT_NONNULL(compressionContext); + if (config->outboundNoContextTakeover) { + // We must reset context on each message. + compressor.reset(); + } + auto& innerMessage = compressedMessage.emplace(compressor.processMessage(message)); + if (message.size() > 0) { + KJ_ASSERT(innerMessage.asPtr().endsWith({0x00, 0x00, 0xFF, 0xFF})); + message = innerMessage.slice(0, innerMessage.size() - 4); + // Strip 0x00 0x00 0xFF 0xFF off the tail. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1 + } else { + // RFC 7692 (7.2.3.6) specifies that an empty uncompressed DEFLATE block (0x00) should be + // built if the compression library doesn't generate data when the input is empty. + message = compressedMessage.emplace(kj::heapArray({0x00})); + } + } +#endif // KJ_HAS_ZLIB + } + + kj::Array ownMessage; + if (!mask.isZero()) { + // Sadness, we have to make a copy to apply the mask. + ownMessage = kj::heapArray(message); mask.apply(ownMessage); message = ownMessage; } - sendParts[0] = sendHeader.compose(true, opcode, message.size(), mask); + sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); sendParts[1] = message; + KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " + "support an extension that would set these bits"); - auto promise = stream->write(sendParts); + auto promise = stream->write(sendParts).attach(kj::mv(compressedMessage)); if (!mask.isZero()) { promise = promise.attach(kj::mv(ownMessage)); } @@ -2635,9 +3365,9 @@ private: queuedPong = kj::mv(payload); } else KJ_IF_MAYBE(promise, sendingPong) { // We're still sending a previous pong. Wait for it to finish before sending ours. - sendingPong = promise->then(kj::mvCapture(payload, [this](kj::Array payload) mutable { + sendingPong = promise->then([this,payload=kj::mv(payload)]() mutable { return sendPong(kj::mv(payload)); - })); + }); } else { // We're not sending any pong currently. sendingPong = sendPong(kj::mv(payload)); @@ -2649,7 +3379,8 @@ private: return kj::READY_NOW; } - sendParts[0] = sendHeader.compose(true, OPCODE_PONG, payload.size(), Mask(maskKeyGenerator)); + sendParts[0] = sendHeader.compose(true, false, OPCODE_PONG, + payload.size(), Mask(maskKeyGenerator)); sendParts[1] = payload; return stream->write(sendParts).attach(kj::mv(payload)); } @@ -2702,19 +3433,24 @@ private: kj::Own upgradeToWebSocket( kj::Own stream, HttpInputStreamImpl& httpInput, HttpOutputStream& httpOutput, - kj::Maybe maskKeyGenerator) { + kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfig = nullptr, + kj::Maybe errorHandler = nullptr) { // Create a WebSocket upgraded from an HTTP stream. auto releasedBuffer = httpInput.releaseBuffer(); return kj::heap(kj::mv(stream), maskKeyGenerator, - kj::mv(releasedBuffer.buffer), releasedBuffer.leftover, - httpOutput.flush()); + kj::mv(compressionConfig), errorHandler, + kj::mv(releasedBuffer.buffer), + releasedBuffer.leftover, httpOutput.flush()); } } // namespace kj::Own newWebSocket(kj::Own stream, - kj::Maybe maskKeyGenerator) { - return kj::heap(kj::mv(stream), maskKeyGenerator); + kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfig, + kj::Maybe errorHandler) { + return kj::heap(kj::mv(stream), maskKeyGenerator, kj::mv(compressionConfig), errorHandler); } static kj::Promise pumpWebSocketLoop(WebSocket& from, WebSocket& to) { @@ -2868,13 +3604,18 @@ public: } } kj::Promise pumpTo(WebSocket& other) override { + auto onAbort = other.whenAborted() + .then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); + }); + KJ_IF_MAYBE(s, state) { auto before = other.receivedByteCount(); return s->pumpTo(other).attach(kj::defer([this, &other, before]() { transferredBytes += other.receivedByteCount() - before; - })); + })).exclusiveJoin(kj::mv(onAbort)); } else { - return newAdaptedPromise(*this, other); + return newAdaptedPromise(*this, other).exclusiveJoin(kj::mv(onAbort)); } } @@ -3219,6 +3960,11 @@ private: canceler.release(); pipe.endState(*this); fulfiller.fulfill(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); })); } kj::Promise disconnect() override { @@ -3228,6 +3974,11 @@ private: pipe.endState(*this); fulfiller.fulfill(); return pipe.disconnect(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); })); } kj::Maybe> tryPumpFrom(WebSocket& other) override { @@ -3236,6 +3987,11 @@ private: canceler.release(); pipe.endState(*this); fulfiller.fulfill(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); })); } @@ -3403,8 +4159,846 @@ WebSocketPipe newWebSocketPipe() { } // ======================================================================================= +class AsyncIoStreamWithInitialBuffer final: public kj::AsyncIoStream { + // An AsyncIoStream implementation that accepts an initial buffer of data + // to be read out first, and is optionally capable of deferring writes + // until a given waitBeforeSend promise is fulfilled. + // + // Instances are created with a leftoverBackingBuffer (a kj::Array) + // and a leftover kj::ArrayPtr that provides a view into the backing + // buffer representing the queued data that is pending to be read. Calling + // tryRead will consume the data from the leftover first. Once leftover has + // been fully consumed, reads will defer to the underlying stream. +public: + AsyncIoStreamWithInitialBuffer(kj::Own stream, + kj::Array leftoverBackingBuffer, + kj::ArrayPtr leftover) + : stream(kj::mv(stream)), + leftoverBackingBuffer(kj::mv(leftoverBackingBuffer)), + leftover(leftover) {} + + void shutdownWrite() override { + stream->shutdownWrite(); + } + + // AsyncInputStream + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(maxBytes >= minBytes); + auto destination = static_cast(buffer); + + // If there are at least minBytes available in the leftover buffer... + if (leftover.size() >= minBytes) { + // We are going to immediately read up to maxBytes from the leftover buffer... + auto bytesToCopy = kj::min(maxBytes, leftover.size()); + memcpy(destination, leftover.begin(), bytesToCopy); + leftover = leftover.slice(bytesToCopy, leftover.size()); + + // If we've consumed all of the data in the leftover buffer, go ahead and free it. + if (leftover.size() == 0) { + leftoverBackingBuffer = nullptr; + } + + return bytesToCopy; + } else { + // We know here that leftover.size() is less than minBytes, but it might not + // be zero. Copy everything from leftover into the destination buffer then read + // the rest from the underlying stream. + auto bytesToCopy = leftover.size(); + KJ_DASSERT(bytesToCopy < minBytes); + + if (bytesToCopy > 0) { + memcpy(destination, leftover.begin(), bytesToCopy); + leftoverBackingBuffer = nullptr; + minBytes -= bytesToCopy; + maxBytes -= bytesToCopy; + KJ_DASSERT(minBytes >= 1); + KJ_DASSERT(maxBytes >= minBytes); + } + + return stream->tryRead(destination + bytesToCopy, minBytes, maxBytes) + .then([bytesToCopy](size_t amount) { return amount + bytesToCopy; }); + } + } + + Maybe tryGetLength() override { + // For a CONNECT pipe, we have no idea how much data there is going to be. + return nullptr; + } + + kj::Promise pumpTo(AsyncOutputStream& output, + uint64_t amount = kj::maxValue) override { + return pumpLoop(output, amount, 0); + } + + kj::Maybe> tryPumpFrom(AsyncInputStream& input, + uint64_t amount = kj::maxValue) override { + return input.pumpTo(*stream, amount); + } + + // AsyncOutputStream + Promise write(const void* buffer, size_t size) override { + return stream->write(buffer, size); + } + + Promise write(ArrayPtr> pieces) override { + return stream->write(pieces); + } + + Promise whenWriteDisconnected() override { + return stream->whenWriteDisconnected(); + } +private: + + kj::Promise pumpLoop( + kj::AsyncOutputStream& output, + uint64_t remaining, + uint64_t total) { + // If there is any data remaining in the leftover queue, we'll write it out first to output. + if (leftover.size() > 0) { + auto bytesToWrite = kj::min(leftover.size(), remaining); + return output.write(leftover.begin(), bytesToWrite).then( + [this, &output, remaining, total, bytesToWrite]() mutable -> kj::Promise { + leftover = leftover.slice(bytesToWrite, leftover.size()); + // If the leftover buffer has been fully consumed, go ahead and free it now. + if (leftover.size() == 0) { + leftoverBackingBuffer = nullptr; + } + remaining -= bytesToWrite; + total += bytesToWrite; + + if (remaining == 0) { + return total; + } + return pumpLoop(output, remaining, total); + }); + } else { + // Otherwise, we are just going to defer to stream's pumpTo, making sure to + // account for the total amount we've already written from the leftover queue. + return stream->pumpTo(output, remaining).then([total](auto read) { + return total + read; + }); + } + }; + + kj::Own stream; + kj::Array leftoverBackingBuffer; + kj::ArrayPtr leftover; +}; + +class AsyncIoStreamWithGuards final: public kj::AsyncIoStream, + private kj::TaskSet::ErrorHandler { + // This AsyncIoStream adds separate kj::Promise guards to both the input and output, + // delaying reads and writes until each relevant guard is resolved. + // + // When the read guard promise resolves, it may provide a released buffer that will + // be read out first. + // The primary use case for this impl is to support pipelined CONNECT calls which + // optimistically allow outbound writes to happen while establishing the CONNECT + // tunnel has not yet been completed. If the guard promise rejects, the stream + // is permanently errored and existing pending calls (reads and writes) are canceled. +public: + AsyncIoStreamWithGuards( + kj::Own inner, + kj::Promise> readGuard, + kj::Promise writeGuard) + : inner(kj::mv(inner)), + readGuard(handleReadGuard(kj::mv(readGuard))), + writeGuard(handleWriteGuard(kj::mv(writeGuard))), + tasks(*this) {} + + // AsyncInputStream + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + if (readGuardReleased) { + return inner->tryRead(buffer, minBytes, maxBytes); + } + return readGuard.addBranch().then([this, buffer, minBytes, maxBytes] { + return inner->tryRead(buffer, minBytes, maxBytes); + }); + } + + Maybe tryGetLength() override { + return nullptr; + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount = kj::maxValue) override { + if (readGuardReleased) { + return inner->pumpTo(output, amount); + } + return readGuard.addBranch().then([this, &output, amount] { + return inner->pumpTo(output, amount); + }); + } + + // AsyncOutputStream + + void shutdownWrite() override { + if (writeGuardReleased) { + inner->shutdownWrite(); + } else { + tasks.add(writeGuard.addBranch().then([this]() { inner->shutdownWrite(); })); + } + } + + kj::Maybe> tryPumpFrom(AsyncInputStream& input, + uint64_t amount = kj::maxValue) override { + if (writeGuardReleased) { + return input.pumpTo(*inner, amount); + } else { + return writeGuard.addBranch().then([this,&input,amount]() { + return input.pumpTo(*inner, amount); + }); + } + } + + Promise write(const void* buffer, size_t size) override { + if (writeGuardReleased) { + return inner->write(buffer, size); + } else { + return writeGuard.addBranch().then([this,buffer,size]() { + return inner->write(buffer, size); + }); + } + } + + Promise write(ArrayPtr> pieces) override { + if (writeGuardReleased) { + return inner->write(pieces); + } else { + return writeGuard.addBranch().then([this, pieces]() { + return inner->write(pieces); + }); + } + } + + Promise whenWriteDisconnected() override { + if (writeGuardReleased) { + return inner->whenWriteDisconnected(); + } else { + return writeGuard.addBranch().then([this]() { + return inner->whenWriteDisconnected(); + }, [](kj::Exception&& e) mutable -> kj::Promise { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return kj::READY_NOW; + } else { + return kj::mv(e); + } + }); + } + } + +private: + kj::Own inner; + kj::ForkedPromise readGuard; + kj::ForkedPromise writeGuard; + bool readGuardReleased = false; + bool writeGuardReleased = false; + kj::TaskSet tasks; + // Set of tasks used to call `shutdownWrite` after write guard is released. + + void taskFailed(kj::Exception&& exception) override { + // This `taskFailed` callback is only used when `shutdownWrite` is being called. Because we + // don't care about DISCONNECTED exceptions when `shutdownWrite` is called we ignore this + // class of exceptions here. + if (exception.getType() != kj::Exception::Type::DISCONNECTED) { + KJ_LOG(ERROR, exception); + } + } + + kj::ForkedPromise handleWriteGuard(kj::Promise guard) { + return guard.then([this]() { + writeGuardReleased = true; + }).fork(); + } + + kj::ForkedPromise handleReadGuard( + kj::Promise> guard) { + return guard.then([this](kj::Maybe buffer) mutable { + readGuardReleased = true; + KJ_IF_MAYBE(b, buffer) { + if (b->leftover.size() > 0) { + // We only need to replace the inner stream if a non-empty buffer is provided. + inner = heap( + kj::mv(inner), + kj::mv(b->buffer), b->leftover); + } + } + }).fork(); + } +}; + +// ======================================================================================= + +namespace _ { // private implementation details + +kj::ArrayPtr splitNext(kj::ArrayPtr& cursor, char delimiter) { + // Consumes and returns the next item in a delimited list. + // + // If a delimiter is found: + // - `cursor` is updated to point to the rest of the string after the delimiter. + // - The text before the delimiter is returned. + // If no delimiter is found: + // - `cursor` is updated to an empty string. + // - The text that had been in `cursor` is returned. + // + // (It's up to the caller to stop the loop once `cursor` is empty.) + KJ_IF_MAYBE(index, cursor.findFirst(delimiter)) { + auto part = cursor.slice(0, *index); + cursor = cursor.slice(*index + 1, cursor.size()); + return part; + } + kj::ArrayPtr result(kj::mv(cursor)); + cursor = nullptr; + + return result; +} + +void stripLeadingAndTrailingSpace(ArrayPtr& str) { + // Remove any leading/trailing spaces from `str`, modifying it in-place. + while (str.size() > 0 && (str[0] == ' ' || str[0] == '\t')) { + str = str.slice(1, str.size()); + } + while (str.size() > 0 && (str.back() == ' ' || str.back() == '\t')) { + str = str.slice(0, str.size() - 1); + } +} + +kj::Vector> splitParts(kj::ArrayPtr input, char delim) { + // Given a string `input` and a delimiter `delim`, split the string into a vector of substrings, + // separated by the delimiter. Note that leading/trailing whitespace is stripped from each element. + kj::Vector> parts; + + while (input.size() != 0) { + auto part = splitNext(input, delim); + stripLeadingAndTrailingSpace(part); + parts.add(kj::mv(part)); + } + + return parts; +} + +kj::Array toKeysAndVals(const kj::ArrayPtr>& params) { + // Given a collection of parameters (a single offer), parse the parameters into + // pairs. If the parameter contains an `=`, we set the `key` to everything before, and the `value` + // to everything after. Otherwise, we set the `key` to be the entire parameter. + // Either way, both the key and value (if it exists) are stripped of leading & trailing whitespace. + auto result = kj::heapArray(params.size()); + size_t count = 0; + for (const auto& param : params) { + kj::ArrayPtr key; + kj::Maybe> value; + + KJ_IF_MAYBE(index, param.findFirst('=')) { + // Found '=' so we have a value. + key = param.slice(0, *index); + stripLeadingAndTrailingSpace(key); + value = param.slice(*index + 1, param.size()); + KJ_IF_MAYBE(v, value) { + stripLeadingAndTrailingSpace(*v); + } + } else { + key = kj::mv(param); + } + + result[count].key = kj::mv(key); + result[count].val = kj::mv(value); + ++count; + } + return kj::mv(result); +} + +struct ParamType { + enum { CLIENT, SERVER } side; + enum { NO_CONTEXT_TAKEOVER, MAX_WINDOW_BITS } property; +}; + +inline kj::Maybe parseKeyName(kj::ArrayPtr& key) { + // Returns a `ParamType` struct if the `key` is valid and nullptr if invalid. + + if (key == "client_no_context_takeover"_kj) { + return ParamType { ParamType::CLIENT, ParamType::NO_CONTEXT_TAKEOVER }; + } else if (key == "server_no_context_takeover"_kj) { + return ParamType { ParamType::SERVER, ParamType::NO_CONTEXT_TAKEOVER }; + } else if (key == "client_max_window_bits"_kj) { + return ParamType { ParamType::CLIENT, ParamType::MAX_WINDOW_BITS }; + } else if (key == "server_max_window_bits"_kj) { + return ParamType { ParamType::SERVER, ParamType::MAX_WINDOW_BITS }; + } + return nullptr; +} + +kj::Maybe populateUnverifiedConfig(kj::Array& params) { + // Given a collection of pairs, attempt to populate an `UnverifiedConfig` struct. + // If the struct cannot be populated, we return null. + // + // This function populates the struct with what it finds, it does not perform bounds checking or + // concern itself with valid `Value`s (so long as the `Value` is non-empty). + // + // The following issues would prevent a struct from being populated: + // Key issues: + // - `Key` is invalid (see `parseKeyName()`). + // - `Key` is repeated. + // Value issues: + // - Got a `Value` when none was expected (only the `max_window_bits` parameters expect values). + // - Got an empty `Value` (0 characters, or all whitespace characters). + + if (params.size() > 4) { + // We expect 4 `Key`s at most, having more implies repeats/invalid keys are present. + return nullptr; + } + + UnverifiedConfig config; + + for (auto& param : params) { + KJ_IF_MAYBE(paramType, parseKeyName(param.key)) { + // `Key` is valid, but we still want to check for repeats. + const auto& side = paramType->side; + const auto& property = paramType->property; + + if (property == ParamType::NO_CONTEXT_TAKEOVER) { + auto& takeOverSetting = (side == ParamType::CLIENT) ? + config.clientNoContextTakeover : config.serverNoContextTakeover; + + if (takeOverSetting == true) { + // This `Key` is a repeat; invalid config. + return nullptr; + } + + if (param.val != nullptr) { + // The `x_no_context_takeover` parameter shouldn't have a value; invalid config. + return nullptr; + } + + takeOverSetting = true; + } else if (property == ParamType::MAX_WINDOW_BITS) { + auto& maxBitsSetting = + (side == ParamType::CLIENT) ? config.clientMaxWindowBits : config.serverMaxWindowBits; + + if (maxBitsSetting != nullptr) { + // This `Key` is a repeat; invalid config. + return nullptr; + } + + KJ_IF_MAYBE(value, param.val) { + if (value->size() == 0) { + // This is equivalent to `x_max_window_bits=`, since we got an "=" we expected a token + // to follow. + return nullptr; + } + maxBitsSetting = param.val; + } else { + // We know we got this `max_window_bits` parameter in a Request/Response, and we also know + // that it didn't include an "=" (otherwise the value wouldn't be null). + // It's important to retain the information that the parameter was received *without* a + // corresponding value, as this may determine whether the offer is valid or not. + // + // To retain this information, we'll set `maxBitsSetting` to be an empty ArrayPtr so this + // can be dealt with properly later. + maxBitsSetting = ArrayPtr(); + } + } + } else { + // Invalid parameter. + return nullptr; + } + } + return kj::mv(config); +} + +kj::Maybe validateCompressionConfig(UnverifiedConfig&& config, + bool isAgreement) { + // Verifies that the `config` is valid depending on whether we're validating a Request (offer) or + // a Response (agreement). This essentially consumes the `UnverifiedConfig` and converts it into a + // `CompressionParameters` struct. + CompressionParameters result; + + KJ_IF_MAYBE(serverBits, config.serverMaxWindowBits) { + if (serverBits->size() == 0) { + // This means `server_max_window_bits` was passed without a value. Since a value is required, + // this config is invalid. + return nullptr; + } else { + KJ_IF_MAYBE(bits, kj::str(*serverBits).tryParseAs()) { + if (*bits < 8 || 15 < *bits) { + // Out of range -- invalid. + return nullptr; + } + if (isAgreement) { + result.inboundMaxWindowBits = *bits; + } else { + result.outboundMaxWindowBits = *bits; + } + } else { + // Invalid ABNF, expected 1*DIGIT. + return nullptr; + } + } + } + + KJ_IF_MAYBE(clientBits, config.clientMaxWindowBits) { + if (clientBits->size() == 0) { + if (!isAgreement) { + // `client_max_window_bits` does not need to have a value in an offer, let's set it to 15 + // to get the best level of compression. + result.inboundMaxWindowBits = 15; + } else { + // `client_max_window_bits` must have a value in a Response. + return nullptr; + } + } else { + KJ_IF_MAYBE(bits, kj::str(*clientBits).tryParseAs()) { + if (*bits < 8 || 15 < *bits) { + // Out of range -- invalid. + return nullptr; + } + if (isAgreement) { + result.outboundMaxWindowBits = *bits; + } else { + result.inboundMaxWindowBits = *bits; + } + } else { + // Invalid ABNF, expected 1*DIGIT. + return nullptr; + } + } + } + + if (isAgreement) { + result.outboundNoContextTakeover = config.clientNoContextTakeover; + result.inboundNoContextTakeover = config.serverNoContextTakeover; + } else { + result.inboundNoContextTakeover = config.clientNoContextTakeover; + result.outboundNoContextTakeover = config.serverNoContextTakeover; + } + return kj::mv(result); +} + +inline kj::Maybe tryExtractParameters( + kj::Vector>& configuration, + bool isAgreement) { + // If the `configuration` is structured correctly and has no invalid parameters/values, we will + // return a populated `CompressionParameters` struct. + if (configuration.size() == 1) { + // Plain `permessage-deflate`. + return CompressionParameters{}; + } + auto params = configuration.slice(1, configuration.size()); + auto keyMaybeValuePairs = toKeysAndVals(params); + // Parse parameter strings into parameter[=value] pairs. + auto maybeUnverified = populateUnverifiedConfig(keyMaybeValuePairs); + KJ_IF_MAYBE(unverified, maybeUnverified) { + // Parsing succeeded, i.e. the parameter (`key`) names are valid and we don't have + // values for `x_no_context_takeover` parameters (the configuration is structured correctly). + // All that's left is to check the `x_max_window_bits` values (if any are present). + KJ_IF_MAYBE(validConfig, validateCompressionConfig(kj::mv(*unverified), isAgreement)) { + return kj::mv(*validConfig); + } + } + return nullptr; +} + +kj::Vector findValidExtensionOffers(StringPtr offers) { + // A function to be called by the client that wants to offer extensions through + // `Sec-WebSocket-Extensions`. This function takes the value of the header (a string) and + // populates a Vector of all the valid offers. + kj::Vector result; + + auto extensions = splitParts(offers, ','); + + for (const auto& offer : extensions) { + auto splitOffer = splitParts(offer, ';'); + if (splitOffer.front() != "permessage-deflate"_kj) { + continue; + } + KJ_IF_MAYBE(validated, tryExtractParameters(splitOffer, false)) { + // We need to swap the inbound/outbound properties since `tryExtractParameters` thinks we're + // parsing as the server (`isAgreement` is false). + auto tempCtx = validated->inboundNoContextTakeover; + validated->inboundNoContextTakeover = validated->outboundNoContextTakeover; + validated->outboundNoContextTakeover = tempCtx; + auto tempWindow = validated->inboundMaxWindowBits; + validated->inboundMaxWindowBits = validated->outboundMaxWindowBits; + validated->outboundMaxWindowBits = tempWindow; + result.add(kj::mv(*validated)); + } + } + + return kj::mv(result); +} + +kj::String generateExtensionRequest(const ArrayPtr& extensions) { + // Build the `Sec-WebSocket-Extensions` request from the validated parameters. + constexpr auto EXT = "permessage-deflate"_kj; + auto offers = kj::heapArray(extensions.size()); + size_t i = 0; + for (const auto& offer : extensions) { + offers[i] = kj::str(EXT); + if (offer.outboundNoContextTakeover) { + offers[i] = kj::str(offers[i], "; client_no_context_takeover"); + } + if (offer.inboundNoContextTakeover) { + offers[i] = kj::str(offers[i], "; server_no_context_takeover"); + } + if (offer.outboundMaxWindowBits != nullptr) { + auto w = KJ_ASSERT_NONNULL(offer.outboundMaxWindowBits); + offers[i] = kj::str(offers[i], "; client_max_window_bits=", w); + } + if (offer.inboundMaxWindowBits != nullptr) { + auto w = KJ_ASSERT_NONNULL(offer.inboundMaxWindowBits); + offers[i] = kj::str(offers[i], "; server_max_window_bits=", w); + } + ++i; + } + return kj::strArray(offers, ", "); +} + +kj::Maybe tryParseExtensionOffers(StringPtr offers) { + // Given a string of offers, accept the first valid offer by returning a `CompressionParameters` + // struct. If there are no valid offers, return `nullptr`. + auto splitOffers = splitParts(offers, ','); + + for (const auto& offer : splitOffers) { + auto splitOffer = splitParts(offer, ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + // Extension token was invalid. + continue; + } + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, false)) { + return kj::mv(*config); + } + } + return nullptr; +} + +kj::Maybe tryParseAllExtensionOffers(StringPtr offers, + CompressionParameters manualConfig) { + // Similar to `tryParseExtensionOffers()`, however, this function is called when parsing in + // `MANUAL_COMPRESSION` mode. In some cases, the server's configuration might not support the + // `server_no_context_takeover` or `server_max_window_bits` parameters. Essentially, this function + // will look at all the client's offers, and accept the first one that it can support. + // + // We differentiate these functions because in `AUTOMATIC_COMPRESSION` mode, KJ can support these + // server restricting compression parameters. + auto splitOffers = splitParts(offers, ','); + + for (const auto& offer : splitOffers) { + auto splitOffer = splitParts(offer, ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + // Extension token was invalid. + continue; + } + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, false)) { + KJ_IF_MAYBE(finalConfig, compareClientAndServerConfigs(*config, manualConfig)) { + // Found a compatible configuration between the server's config and client's offer. + return kj::mv(*finalConfig); + } + } + } + return nullptr; +} + +kj::Maybe compareClientAndServerConfigs(CompressionParameters requestConfig, + CompressionParameters manualConfig) { + // We start from the `manualConfig` and go through a series of filters to get a compression + // configuration that both the client and the server can agree upon. If no agreement can be made, + // we return null. + + CompressionParameters acceptedParameters = manualConfig; + + // We only need to modify `client_no_context_takeover` and `server_no_context_takeover` when + // `manualConfig` doesn't include them. + if (manualConfig.inboundNoContextTakeover == false) { + acceptedParameters.inboundNoContextTakeover = false; + } + + if (manualConfig.outboundNoContextTakeover == false) { + acceptedParameters.outboundNoContextTakeover = false; + if (requestConfig.outboundNoContextTakeover == true) { + // The client has told the server to not use context takeover. This is not a "hint", + // rather it is a restriction on the server's configuration. If the server does not support + // the configuration, it must reject the offer. + return nullptr; + } + } + + // client_max_window_bits + if (requestConfig.inboundMaxWindowBits != nullptr && + manualConfig.inboundMaxWindowBits != nullptr) { + // We want `min(requestConfig, manualConfig)` in this case. + auto reqBits = KJ_ASSERT_NONNULL(requestConfig.inboundMaxWindowBits); + auto manualBits = KJ_ASSERT_NONNULL(manualConfig.inboundMaxWindowBits); + if (reqBits < manualBits) { + acceptedParameters.inboundMaxWindowBits = reqBits; + } + } else { + // We will not reply with `client_max_window_bits`. + acceptedParameters.inboundMaxWindowBits = nullptr; + } + + // server_max_window_bits + if (manualConfig.outboundMaxWindowBits != nullptr) { + auto manualBits = KJ_ASSERT_NONNULL(manualConfig.outboundMaxWindowBits); + if (requestConfig.outboundMaxWindowBits != nullptr) { + // We want `min(requestConfig, manualConfig)` in this case. + auto reqBits = KJ_ASSERT_NONNULL(requestConfig.outboundMaxWindowBits); + if (reqBits < manualBits) { + acceptedParameters.outboundMaxWindowBits = reqBits; + } + } + } else { + acceptedParameters.outboundMaxWindowBits = nullptr; + if (requestConfig.outboundMaxWindowBits != nullptr) { + // The client has told the server to use `server_max_window_bits`. This is not a "hint", + // rather it is a restriction on the server's configuration. If the server does not support + // the configuration, it must reject the offer. + return nullptr; + } + } + return acceptedParameters; +} + +kj::String generateExtensionResponse(const CompressionParameters& parameters) { + // Build the `Sec-WebSocket-Extensions` response from the agreed parameters. + kj::String response = kj::str("permessage-deflate"); + if (parameters.inboundNoContextTakeover) { + response = kj::str(response, "; client_no_context_takeover"); + } + if (parameters.outboundNoContextTakeover) { + response = kj::str(response, "; server_no_context_takeover"); + } + if (parameters.inboundMaxWindowBits != nullptr) { + auto w = KJ_REQUIRE_NONNULL(parameters.inboundMaxWindowBits); + response = kj::str(response, "; client_max_window_bits=", w); + } + if (parameters.outboundMaxWindowBits != nullptr) { + auto w = KJ_REQUIRE_NONNULL(parameters.outboundMaxWindowBits); + response = kj::str(response, "; server_max_window_bits=", w); + } + return kj::mv(response); +} + +kj::OneOf tryParseExtensionAgreement( + const Maybe& clientOffer, + StringPtr agreedParameters) { + // Like `tryParseExtensionOffers`, but called by the client when parsing the server's Response. + // If the client must decline the agreement, we want to provide some details about what went wrong + // (since the client has to fail the connection). + constexpr auto FAILURE = "Server failed WebSocket handshake: "_kj; + auto e = KJ_EXCEPTION(FAILED); + + if (clientOffer == nullptr) { + // We've received extensions when we did not send any in the first place. + e.setDescription( + kj::str(FAILURE, "added Sec-WebSocket-Extensions when client did not offer any.")); + return kj::mv(e); + } + + auto offers = splitParts(agreedParameters, ','); + if (offers.size() != 1) { + constexpr auto EXPECT = "expected exactly one extension (permessage-deflate) but received " + "more than one."_kj; + e.setDescription(kj::str(FAILURE, EXPECT)); + return kj::mv(e); + } + auto splitOffer = splitParts(offers.front(), ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + e.setDescription(kj::str(FAILURE, "response included a Sec-WebSocket-Extensions value that was " + "not permessage-deflate.")); + return kj::mv(e); + } + + // Verify the parameters of our single extension, and compare it with the clients original offer. + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, true)) { + const auto& client = KJ_ASSERT_NONNULL(clientOffer); + // The server might have ignored the client's hints regarding its compressor's configuration. + // That's fine, but as the client, we still want to use those outbound compression parameters. + if (config->outboundMaxWindowBits == nullptr) { + config->outboundMaxWindowBits = client.outboundMaxWindowBits; + } else KJ_IF_MAYBE(value, client.outboundMaxWindowBits) { + if (*value < KJ_ASSERT_NONNULL(config->outboundMaxWindowBits)) { + // If the client asked for a value smaller than what the server responded with, use the + // value that the client originally specified. + config->outboundMaxWindowBits = *value; + } + } + if (config->outboundNoContextTakeover == false) { + config->outboundNoContextTakeover = client.outboundNoContextTakeover; + } + return kj::mv(*config); + } + + // There was a problem parsing the server's `Sec-WebSocket-Extensions` response. + e.setDescription(kj::str(FAILURE, "the Sec-WebSocket-Extensions header in the Response included " + "an invalid value.")); + return kj::mv(e); +} +} // namespace _ (private) namespace { +class NullInputStream final: public kj::AsyncInputStream { +public: + NullInputStream(kj::Maybe expectedLength = size_t(0)) + : expectedLength(expectedLength) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return constPromise(); + } + + kj::Maybe tryGetLength() override { + return expectedLength; + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return constPromise(); + } + +private: + kj::Maybe expectedLength; +}; + +class NullOutputStream final: public kj::AsyncOutputStream { +public: + Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + return kj::READY_NOW; + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. +}; + +class NullIoStream final: public kj::AsyncIoStream { +public: + void shutdownWrite() override {} + + Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + return kj::READY_NOW; + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return constPromise(); + } + + kj::Maybe tryGetLength() override { + return kj::Maybe((uint64_t)0); + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return constPromise(); + } +}; class HttpClientImpl final: public HttpClient, private HttpClientErrorHandler { @@ -3542,6 +5136,36 @@ public: connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_VERSION] = "13"; connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_KEY] = keyBase64; + kj::Maybe offeredExtensions; + kj::Maybe clientOffer; + kj::Vector extensions; + auto compressionMode = settings.webSocketCompressionMode; + + if (compressionMode == HttpClientSettings::MANUAL_COMPRESSION) { + KJ_IF_MAYBE(value, headers.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Strip all `Sec-WebSocket-Extensions` except for `permessage-deflate`. + extensions = _::findValidExtensionOffers(*value); + } + } else if (compressionMode == HttpClientSettings::AUTOMATIC_COMPRESSION) { + // If AUTOMATIC_COMPRESSION is enabled, we send `Sec-WebSocket-Extensions: permessage-deflate` + // to the server and ignore the `headers` provided by the caller. + extensions.add(CompressionParameters()); + } + + if (extensions.size() > 0) { + clientOffer = extensions.front(); + // We hold on to a copy of the client's most preferred offer so even if the server + // ignores `client_no_context_takeover` or `client_max_window_bits`, we can still refer to + // the original offer made by the client (thereby allowing the client to use these parameters). + // + // It's safe to ignore the remaining offers because: + // 1. Offers are ordered by preference. + // 2. `client_x` parameters are hints to the server and do not result in rejections, so the + // client is likely to put them in every offer anyways. + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_EXTENSIONS] = + offeredExtensions.emplace(_::generateExtensionRequest(extensions.asPtr())); + } + httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders)); // No entity-body. @@ -3550,7 +5174,7 @@ public: auto id = ++counter; return httpInput.readResponseHeaders() - .then([this,id,keyBase64 = kj::mv(keyBase64)]( + .then([this,id,keyBase64 = kj::mv(keyBase64),clientOffer = kj::mv(clientOffer)]( HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) -> HttpClient::WebSocketResponse { KJ_SWITCH_ONEOF(responseOrProtocolError) { @@ -3592,11 +5216,27 @@ public: }); } + kj::Maybe compressionParameters; + if (settings.webSocketCompressionMode != HttpClientSettings::NO_COMPRESSION) { + KJ_IF_MAYBE(agreedParameters, responseHeaders.get( + HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + + auto parseResult = _::tryParseExtensionAgreement(clientOffer, + *agreedParameters); + if (parseResult.is()) { + return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ + 502, "Bad Gateway", parseResult.get().getDescription(), nullptr}); + } + compressionParameters.emplace(kj::mv(parseResult.get())); + } + } + return { response.statusCode, response.statusText, &httpInput.getHeaders(), - upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource), + upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource, + kj::mv(compressionParameters)), }; } else { upgraded = false; @@ -3629,6 +5269,85 @@ public: }); } + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + KJ_REQUIRE(!upgraded, + "can't make further requests on this HttpClient because it has been or is in the process " + "of being upgraded"); + KJ_REQUIRE(!closed, + "this HttpClient's connection has been closed by the server or due to an error"); + KJ_REQUIRE(httpOutput.canReuse(), + "can't start new request until previous request body has been fully written"); + + if (settings.useTls) { + KJ_UNIMPLEMENTED("This HttpClient does not support TLS."); + } + + closeWatcherTask = nullptr; + + // Mark upgraded for now even though the tunnel could fail, because we can't allow pipelined + // requests in the meantime. + upgraded = true; + + kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; + + httpOutput.writeHeaders(headers.serializeConnectRequest(host, connectionHeaders)); + + auto id = ++counter; + + auto split = httpInput.readResponseHeaders().then( + [this, id](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) mutable + -> kj::Tuple, + kj::Promise>> { + KJ_SWITCH_ONEOF(responseOrProtocolError) { + KJ_CASE_ONEOF(response, HttpHeaders::Response) { + auto& responseHeaders = httpInput.getHeaders(); + if (response.statusCode < 200 || response.statusCode >= 300) { + // Any statusCode that is not in the 2xx range in interpreted + // as an HTTP response. Any status code in the 2xx range is + // interpreted as a successful CONNECT response. + closed = true; + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(responseHeaders.clone()), + httpInput.getEntityBody( + HttpInputStreamImpl::RESPONSE, + HttpConnectMethod(), + response.statusCode, + responseHeaders)), + KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + } + KJ_ASSERT(counter == id); + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(responseHeaders.clone()) + ), kj::Maybe(httpInput.releaseBuffer())); + } + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + closed = true; + auto response = handleProtocolError(protocolError); + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(response.headers->clone()), + kj::mv(response.body) + ), KJ_EXCEPTION(DISCONNECTED, "the connect request errored")); + } + } + KJ_UNREACHABLE; + }).split(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), // Promise for the result + heap( + kj::mv(ownStream), + kj::mv(kj::get<1>(split)) /* read guard (Promise for the ReleasedBuffer) */, + httpOutput.flush() /* write guard (void Promise) */) + }; + } + private: HttpInputStreamImpl httpInput; HttpOutputStream httpOutput; @@ -3696,7 +5415,8 @@ kj::Promise HttpClient::openWebSocket( }); } -kj::Promise> HttpClient::connect(kj::StringPtr host) { +HttpClient::ConnectRequest HttpClient::connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) { KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpClient"); } @@ -3722,6 +5442,169 @@ HttpClient::WebSocketResponse HttpClientErrorHandler::handleWebSocketProtocolErr }; } +kj::Exception WebSocketErrorHandler::handleWebSocketProtocolError( + WebSocket::ProtocolError protocolError) { + return KJ_EXCEPTION(FAILED, kj::str("code=", protocolError.statusCode, + ": ", protocolError.description)); +} + +class PausableReadAsyncIoStream::PausableRead { +public: + PausableRead( + kj::PromiseFulfiller& fulfiller, PausableReadAsyncIoStream& parent, + void* buffer, size_t minBytes, size_t maxBytes) + : fulfiller(fulfiller), parent(parent), + operationBuffer(buffer), operationMinBytes(minBytes), operationMaxBytes(maxBytes), + innerRead(parent.tryReadImpl(operationBuffer, operationMinBytes, operationMaxBytes).then( + [&fulfiller](size_t size) mutable -> kj::Promise { + fulfiller.fulfill(kj::mv(size)); + return kj::READY_NOW; + }, [&fulfiller](kj::Exception&& err) { + fulfiller.reject(kj::mv(err)); + })) { + KJ_ASSERT(parent.maybePausableRead == nullptr); + parent.maybePausableRead = *this; + } + + ~PausableRead() noexcept(false) { + parent.maybePausableRead = nullptr; + } + + void pause() { + innerRead = nullptr; + } + + void unpause() { + innerRead = parent.tryReadImpl(operationBuffer, operationMinBytes, operationMaxBytes).then( + [this](size_t size) -> kj::Promise { + fulfiller.fulfill(kj::mv(size)); + return kj::READY_NOW; + }, [this](kj::Exception&& err) { + fulfiller.reject(kj::mv(err)); + }); + } + + void reject(kj::Exception&& exc) { + fulfiller.reject(kj::mv(exc)); + } +private: + kj::PromiseFulfiller& fulfiller; + PausableReadAsyncIoStream& parent; + + void* operationBuffer; + size_t operationMinBytes; + size_t operationMaxBytes; + // The parameters of the current tryRead call. Used to unpause a paused read. + + kj::Promise innerRead; + // The current pending read. +}; + +_::Deferred> PausableReadAsyncIoStream::trackRead() { + KJ_REQUIRE(!currentlyReading, "only one read is allowed at any one time"); + currentlyReading = true; + return kj::defer>([this]() { currentlyReading = false; }); +} + +_::Deferred> PausableReadAsyncIoStream::trackWrite() { + KJ_REQUIRE(!currentlyWriting, "only one write is allowed at any one time"); + currentlyWriting = true; + return kj::defer>([this]() { currentlyWriting = false; }); +} + +kj::Promise PausableReadAsyncIoStream::tryRead( + void* buffer, size_t minBytes, size_t maxBytes) { + return kj::newAdaptedPromise(*this, buffer, minBytes, maxBytes); +} + +kj::Promise PausableReadAsyncIoStream::tryReadImpl( + void* buffer, size_t minBytes, size_t maxBytes) { + // Hack: evalNow used here because `newAdaptedPromise` has a bug. We may need to change + // `PromiseDisposer::alloc` to not be `noexcept` but in order to do so we'll need to benchmark + // its performance. + return kj::evalNow([&]() -> kj::Promise { + return inner->tryRead(buffer, minBytes, maxBytes).attach(trackRead()); + }); +} + +kj::Maybe PausableReadAsyncIoStream::tryGetLength() { + return inner->tryGetLength(); +} + +kj::Promise PausableReadAsyncIoStream::pumpTo( + kj::AsyncOutputStream& output, uint64_t amount) { + return kj::unoptimizedPumpTo(*this, output, amount); +} + +kj::Promise PausableReadAsyncIoStream::write(const void* buffer, size_t size) { + return inner->write(buffer, size).attach(trackWrite()); +} + +kj::Promise PausableReadAsyncIoStream::write( + kj::ArrayPtr> pieces) { + return inner->write(pieces).attach(trackWrite()); +} + +kj::Maybe> PausableReadAsyncIoStream::tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount) { + auto result = inner->tryPumpFrom(input, amount); + KJ_IF_MAYBE(r, result) { + return r->attach(trackWrite()); + } else { + return nullptr; + } +} + +kj::Promise PausableReadAsyncIoStream::whenWriteDisconnected() { + return inner->whenWriteDisconnected(); +} + +void PausableReadAsyncIoStream::shutdownWrite() { + inner->shutdownWrite(); +} + +void PausableReadAsyncIoStream::abortRead() { + inner->abortRead(); +} + +kj::Maybe PausableReadAsyncIoStream::getFd() const { + return inner->getFd(); +} + +void PausableReadAsyncIoStream::pause() { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->pause(); + } +} + +void PausableReadAsyncIoStream::unpause() { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->unpause(); + } +} + +bool PausableReadAsyncIoStream::getCurrentlyReading() { + return currentlyReading; +} + +bool PausableReadAsyncIoStream::getCurrentlyWriting() { + return currentlyWriting; +} + +kj::Own PausableReadAsyncIoStream::takeStream() { + return kj::mv(inner); +} + +void PausableReadAsyncIoStream::replaceStream(kj::Own stream) { + inner = kj::mv(stream); +} + +void PausableReadAsyncIoStream::reject(kj::Exception&& exc) { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->reject(kj::mv(exc)); + } +} + // ======================================================================================= namespace { @@ -3752,11 +5635,11 @@ public: auto refcounted = getClient(); auto result = refcounted->client->request(method, url, headers, expectedBodySize); result.body = result.body.attach(kj::addRef(*refcounted)); - result.response = result.response.then(kj::mvCapture(refcounted, - [](kj::Own&& refcounted, Response&& response) { + result.response = result.response.then( + [refcounted=kj::mv(refcounted)](Response&& response) mutable { response.body = response.body.attach(kj::mv(refcounted)); return kj::mv(response); - })); + }); return result; } @@ -3764,8 +5647,8 @@ public: kj::StringPtr url, const HttpHeaders& headers) override { auto refcounted = getClient(); auto result = refcounted->client->openWebSocket(url, headers); - return result.then(kj::mvCapture(refcounted, - [](kj::Own&& refcounted, WebSocketResponse&& response) { + return result.then( + [refcounted=kj::mv(refcounted)](WebSocketResponse&& response) mutable { KJ_SWITCH_ONEOF(response.webSocketOrBody) { KJ_CASE_ONEOF(body, kj::Own) { response.webSocketOrBody = body.attach(kj::mv(refcounted)); @@ -3780,7 +5663,17 @@ public: } } return kj::mv(response); - })); + }); + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + auto refcounted = getClient(); + auto request = refcounted->client->connect(host, headers, settings); + return ConnectRequest { + request.status.attach(kj::addRef(*refcounted)), + request.connection.attach(kj::mv(refcounted)) + }; } private: @@ -3875,6 +5768,75 @@ private: } }; +class TransitionaryAsyncIoStream final: public kj::AsyncIoStream { + // This specialised AsyncIoStream is used by NetworkHttpClient to support startTls. +public: + TransitionaryAsyncIoStream(kj::Own unencryptedStream) + : inner(kj::heap(kj::mv(unencryptedStream))) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + + kj::Maybe tryGetLength() override { + return inner->tryGetLength(); + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return inner->pumpTo(output, amount); + } + + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + + kj::Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + + void shutdownWrite() override { + inner->shutdownWrite(); + } + + void abortRead() override { + inner->abortRead(); + } + + kj::Maybe getFd() const override { + return inner->getFd(); + } + + void startTls( + kj::SecureNetworkWrapper* wrapper, kj::StringPtr expectedServerHostname) { + // Pause any potential pending reads. + inner->pause(); + + KJ_ON_SCOPE_FAILURE({ + inner->reject(KJ_EXCEPTION(FAILED, "StartTls failed.")); + }); + + KJ_ASSERT(!inner->getCurrentlyReading() && !inner->getCurrentlyWriting(), + "Cannot call startTls while reads/writes are outstanding"); + kj::Promise> secureStream = + wrapper->wrapClient(inner->takeStream(), expectedServerHostname); + inner->replaceStream(kj::newPromisedStream(kj::mv(secureStream))); + // Resume any previous pending reads. + inner->unpause(); + } + +private: + kj::Own inner; +}; + class PromiseNetworkAddressHttpClient final: public HttpClient { // An HttpClient which waits for a promise to resolve then forwards all calls to the promised // client. @@ -3915,12 +5877,12 @@ public: // This gets complicated since request() returns a pair of a stream and a promise. auto urlCopy = kj::str(url); auto headersCopy = headers.clone(); - auto combined = promise.addBranch().then(kj::mvCapture(urlCopy, kj::mvCapture(headersCopy, - [this,method,expectedBodySize](HttpHeaders&& headers, kj::String&& url) + auto combined = promise.addBranch().then( + [this,method,expectedBodySize,url=kj::mv(urlCopy), headers=kj::mv(headersCopy)]() -> kj::Tuple, kj::Promise> { auto req = KJ_ASSERT_NONNULL(client)->request(method, url, headers, expectedBodySize); return kj::tuple(kj::mv(req.body), kj::mv(req.response)); - }))); + }); auto split = combined.split(); return { @@ -3937,10 +5899,30 @@ public: } else { auto urlCopy = kj::str(url); auto headersCopy = headers.clone(); - return promise.addBranch().then(kj::mvCapture(urlCopy, kj::mvCapture(headersCopy, - [this](HttpHeaders&& headers, kj::String&& url) { + return promise.addBranch().then( + [this,url=kj::mv(urlCopy),headers=kj::mv(headersCopy)]() { return KJ_ASSERT_NONNULL(client)->openWebSocket(url, headers); - }))); + }); + } + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + KJ_IF_MAYBE(c, client) { + return c->get()->connect(host, headers, settings); + } else { + auto split = promise.addBranch().then( + [this, host=kj::str(host), headers=headers.clone(), settings]() mutable + -> kj::Tuple, + kj::Promise>> { + auto request = KJ_ASSERT_NONNULL(client)->connect(host, headers, kj::mv(settings)); + return kj::tuple(kj::mv(request.status), kj::mv(request.connection)); + }).split(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::newPromisedStream(kj::mv(kj::get<1>(split))) + }; } } @@ -3992,6 +5974,58 @@ public: return getClient(parsed).openWebSocket(path, headersCopy); } + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, + HttpConnectSettings connectSettings) override { + // We want to connect directly instead of going through a proxy here. + // https://github.com/capnproto/capnproto/pull/1454#discussion_r900414879 + kj::Maybe>> addr; + if (connectSettings.useTls) { + kj::Network& tlsNet = KJ_REQUIRE_NONNULL(tlsNetwork, "this HttpClient doesn't support TLS"); + addr = tlsNet.parseAddress(host); + } else { + addr = network.parseAddress(host); + } + + auto split = KJ_ASSERT_NONNULL(addr).then([this](auto address) { + return address->connect().then([this](auto connection) + -> kj::Tuple, + kj::Promise>> { + return kj::tuple( + ConnectRequest::Status( + 200, + kj::str("OK"), + kj::heap(responseHeaderTable) // Empty headers + ), + kj::mv(connection)); + }).attach(kj::mv(address)); + }).split(); + + auto connection = kj::newPromisedStream(kj::mv(kj::get<1>(split))); + + if (!connectSettings.useTls) { + KJ_IF_MAYBE(wrapper, settings.tlsContext) { + KJ_IF_MAYBE(tlsStarter, connectSettings.tlsStarter) { + auto transitConnectionRef = kj::refcountedWrapper( + kj::heap(kj::mv(connection))); + Function(kj::StringPtr)> cb = + [wrapper, ref1 = transitConnectionRef->addWrappedRef()]( + kj::StringPtr expectedServerHostname) mutable { + ref1->startTls(wrapper, expectedServerHostname); + return kj::READY_NOW; + }; + connection = transitConnectionRef->addWrappedRef(); + *tlsStarter = kj::mv(cb); + } + } + } + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::mv(connection) + }; + } + private: kj::Timer& timer; const HttpHeaderTable& responseHeaderTable; @@ -4103,6 +6137,7 @@ namespace { class ConcurrencyLimitingHttpClient final: public HttpClient { public: + KJ_DISALLOW_COPY_AND_MOVE(ConcurrencyLimitingHttpClient); ConcurrencyLimitingHttpClient( kj::HttpClient& inner, uint maxConcurrentRequests, kj::Function countChangedCallback) @@ -4110,6 +6145,16 @@ public: maxConcurrentRequests(maxConcurrentRequests), countChangedCallback(kj::mv(countChangedCallback)) {} + ~ConcurrencyLimitingHttpClient() noexcept(false) { + if (concurrentRequests > 0) { + static bool logOnce KJ_UNUSED = ([&] { + KJ_LOG(ERROR, "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " + "are still active", concurrentRequests); + return true; + })(); + } + } + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, kj::Maybe expectedBodySize = nullptr) override { if (concurrentRequests < maxConcurrentRequests) { @@ -4164,6 +6209,35 @@ public: return kj::mv(promise); } + ConnectRequest connect( + kj::StringPtr host, const kj::HttpHeaders& headers, HttpConnectSettings settings) override { + if (concurrentRequests < maxConcurrentRequests) { + auto counter = ConnectionCounter(*this); + auto response = inner.connect(host, headers, settings); + fireCountChanged(); + return attachCounter(kj::mv(response), kj::mv(counter)); + } + + auto paf = kj::newPromiseAndFulfiller(); + + auto split = paf.promise + .then([this, host=kj::str(host), headers=headers.clone(), settings] + (ConnectionCounter&& counter) mutable + -> kj::Tuple, + kj::Promise>> { + auto request = attachCounter(inner.connect(host, headers, settings), kj::mv(counter)); + return kj::tuple(kj::mv(request.status), kj::mv(request.connection)); + }).split(); + + pendingRequests.push(kj::mv(paf.fulfiller)); + fireCountChanged(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::newPromisedStream(kj::mv(kj::get<1>(split))) + }; + } + private: struct ConnectionCounter; @@ -4202,12 +6276,15 @@ private: }; void serviceQueue() { - if (concurrentRequests >= maxConcurrentRequests) { return; } - if (pendingRequests.empty()) { return; } - - auto fulfiller = kj::mv(pendingRequests.front()); - pendingRequests.pop(); - fulfiller->fulfill(ConnectionCounter(*this)); + while (concurrentRequests < maxConcurrentRequests && !pendingRequests.empty()) { + auto fulfiller = kj::mv(pendingRequests.front()); + pendingRequests.pop(); + // ConnectionCounter's destructor calls this function, so we can avoid unnecessary recursion + // if we only create a ConnectionCounter when we find a waiting fulfiller. + if (fulfiller->isWaiting()) { + fulfiller->fulfill(ConnectionCounter(*this)); + } + } } void fireCountChanged() { @@ -4251,6 +6328,21 @@ private: }; }); } + + static ConnectRequest attachCounter( + ConnectRequest&& request, + ConnectionCounter&& counter) { + // Notice here that we are only attaching the counter to the connection stream. In the case + // where the connect tunnel request is rejected and the status promise resolves with an + // errorBody, there is a possibility that the consuming code might drop the connection stream + // and the counter while the error body stream is still be consumed. Technically speaking that + // means we could potentially exceed our concurrency limit temporarily but we consider that + // acceptable here since the error body is an exception path (plus not requiring that we + // attach to the errorBody keeps ConnectionCounter from having to be a refcounted heap + // allocation). + request.connection = request.connection.attach(kj::mv(counter)); + return kj::mv(request); + } }; } @@ -4266,42 +6358,6 @@ kj::Own newConcurrencyLimitingHttpClient( namespace { -class NullInputStream final: public kj::AsyncInputStream { -public: - NullInputStream(kj::Maybe expectedLength = size_t(0)) - : expectedLength(expectedLength) {} - - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); - } - - kj::Maybe tryGetLength() override { - return expectedLength; - } - - kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - return uint64_t(0); - } - -private: - kj::Maybe expectedLength; -}; - -class NullOutputStream final: public kj::AsyncOutputStream { -public: - Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - Promise write(ArrayPtr> pieces) override { - return kj::READY_NOW; - } - Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. -}; - class HttpClientAdapter final: public HttpClient { public: HttpClientAdapter(HttpService& service): service(service) {} @@ -4359,8 +6415,60 @@ public: return paf.promise.attach(kj::mv(responder)); } - kj::Promise> connect(kj::StringPtr host) override { - return service.connect(kj::mv(host)); + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + // We have to clone the host and the headers because HttpServer implementation are allowed to + // assusme that they remain valid until the service handler completes whereas HttpClient callers + // are allowed to destroy them immediately after the call. + auto hostCopy = kj::str(host); + auto headersCopy = kj::heap(headers.clone()); + + // 1. Create a new TwoWayPipe, one will be returned with the ConnectRequest, + // the other will be held by the ConnectResponseImpl. + auto pipe = kj::newTwoWayPipe(); + + // 2. Create a promise/fulfiller pair for the status. The promise will be + // returned with the ConnectResponse, the fulfiller will be held by the + // ConnectResponseImpl. + auto paf = kj::newPromiseAndFulfiller(); + + // 3. Create the ConnectResponseImpl + auto response = kj::refcounted(kj::mv(paf.fulfiller), + kj::mv(pipe.ends[0])); + + // 5. Call service.connect, passing in the tunnel. + // The call to tunnel->getConnectStream() returns a guarded stream that will buffer + // writes until the status is indicated by calling accept/reject. + auto connectStream = response->getConnectStream(); + auto promise = service.connect(hostCopy, *headersCopy, *connectStream, *response, settings) + .eagerlyEvaluate([response=kj::mv(response), + host=kj::mv(hostCopy), + headers=kj::mv(headersCopy), + connectStream=kj::mv(connectStream)](kj::Exception&& ex) mutable { + // A few things need to happen here. + // 1. We'll log the exception. + // 2. We'll break the pipe. + // 3. We'll reject the status promise if it is still pending. + // + // We'll do all of this within the ConnectResponseImpl, however, since it + // maintains the state necessary here. + response->handleException(kj::mv(ex), kj::mv(connectStream)); + }); + + // TODO(bug): There's a challenge with attaching the service.connect promise to the + // connection stream below in that the client will likely drop the connection as soon + // as it reads EOF, but the promise representing the service connect() call may still + // be running and want to do some cleanup after it has sent EOF. That cleanup will be + // canceled. For regular HTTP calls, DelayedEofInputStream was created to address this + // exact issue but with connect() being bidirectional it's rather more difficult. We + // want a delay similar to what DelayedEofInputStream adds but only when both directions + // have been closed. That currently is not possible until we have an alternative to + // shutdownWrite() that returns a Promise (e.g. Promise end()). For now, we can + // live with the current limitation. + return ConnectRequest { + kj::mv(paf.promise), + pipe.ends[1].attach(kj::mv(promise)), + }; } private: @@ -4658,6 +6766,119 @@ private: kj::Own> fulfiller; kj::Promise task = nullptr; }; + + class ConnectResponseImpl final: public HttpService::ConnectResponse, public kj::Refcounted { + public: + ConnectResponseImpl( + kj::Own> fulfiller, + kj::Own stream) + : fulfiller(kj::mv(fulfiller)), + streamAndFulfiller(initStreamsAndFulfiller(kj::mv(stream))) {} + + ~ConnectResponseImpl() noexcept(false) { + if (fulfiller->isWaiting() || streamAndFulfiller.fulfiller->isWaiting()) { + auto ex = KJ_EXCEPTION(FAILED, + "service's connect() implementation never called accept() nor reject()"); + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::cp(ex)); + } + if (streamAndFulfiller.fulfiller->isWaiting()) { + streamAndFulfiller.fulfiller->reject(kj::mv(ex)); + } + } + } + + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + respond(statusCode, statusText, headers); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(statusCode < 200 || statusCode >= 300, + "the statusCode must not be 2xx for reject."); + auto pipe = kj::newOneWayPipe(); + respond(statusCode, statusText, headers, kj::mv(pipe.in)); + return kj::mv(pipe.out); + } + + private: + struct StreamsAndFulfiller { + // guarded is the wrapped/guarded stream that wraps a reference to + // the underlying stream but blocks reads until the connection is accepted + // or rejected. + // This will be handed off when getConnectStream() is called. + // The fulfiller is used to resolve the guard for the second stream. This will + // be fulfilled or rejected when accept/reject is called. + kj::Own guarded; + kj::Own> fulfiller; + }; + + kj::Own> fulfiller; + StreamsAndFulfiller streamAndFulfiller; + bool connectStreamDetached = false; + + StreamsAndFulfiller initStreamsAndFulfiller(kj::Own stream) { + auto paf = kj::newPromiseAndFulfiller(); + auto guarded = kj::heap( + kj::mv(stream), + kj::Maybe(nullptr), + kj::mv(paf.promise)); + return StreamsAndFulfiller { + kj::mv(guarded), + kj::mv(paf.fulfiller) + }; + } + + void handleException(kj::Exception&& ex, kj::Own connectStream) { + // Log the exception... + KJ_LOG(ERROR, "Error in HttpClientAdapter connect()", kj::cp(ex)); + // Reject the status promise if it is still pending... + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::cp(ex)); + } + if (streamAndFulfiller.fulfiller->isWaiting()) { + // If the guard hasn't yet ben released, we can fail the pending reads by + // rejecting the fulfiller here. + streamAndFulfiller.fulfiller->reject(kj::mv(ex)); + } else { + // The guard has already been released at this point. + // TODO(connect): How to properly propagate the actual exception to the + // connect stream? Here we "simply" shut it down. + connectStream->abortRead(); + connectStream->shutdownWrite(); + } + } + + kj::Own getConnectStream() { + KJ_ASSERT(!connectStreamDetached, "the connect stream was already detached"); + connectStreamDetached = true; + return streamAndFulfiller.guarded.attach(kj::addRef(*this)); + } + + void respond(uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe> errorBody = nullptr) { + if (errorBody == nullptr) { + streamAndFulfiller.fulfiller->fulfill(); + } else { + streamAndFulfiller.fulfiller->reject( + KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + } + fulfiller->fulfill(HttpClient::ConnectRequest::Status( + statusCode, + kj::str(statusText), + kj::heap(headers.clone()), + kj::mv(errorBody))); + } + + friend class HttpClientAdapter; + }; + }; } // namespace @@ -4693,7 +6914,7 @@ public: return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body)); })); - return kj::joinPromises(promises.finish()); + return kj::joinPromisesFailFast(promises.finish()); } else { return client.openWebSocket(url, headers) .then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise { @@ -4703,7 +6924,7 @@ public: auto promises = kj::heapArrayBuilder>(2); promises.add(ws->pumpTo(*ws2)); promises.add(ws2->pumpTo(*ws)); - return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2)); + return kj::joinPromisesFailFast(promises.finish()).attach(kj::mv(ws), kj::mv(ws2)); } KJ_CASE_ONEOF(body, kj::Own) { auto out = response.send( @@ -4718,8 +6939,69 @@ public: } } - kj::Promise> connect(kj::StringPtr host) override { - return client.connect(kj::mv(host)); + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + HttpConnectSettings settings) override { + KJ_REQUIRE(!headers.isWebSocket(), "WebSocket upgrade headers are not permitted in a connect."); + + auto request = client.connect(host, headers, settings); + + // This operates optimistically. In order to support pipelining, we connect the + // input and outputs streams immediately, even if we're not yet certain that the + // tunnel can actually be established. + auto promises = kj::heapArrayBuilder>(2); + + // For the inbound pipe (from the clients stream to the passed in stream) + // We want to guard reads pending the acceptance of the tunnel. If the + // tunnel is not accepted, the guard will be rejected, causing pending + // reads to fail. + auto paf = kj::newPromiseAndFulfiller>(); + auto io = kj::heap( + kj::mv(request.connection), + kj::mv(paf.promise) /* read guard */, + kj::READY_NOW /* write guard */); + + // Writing from connection to io is unguarded and allowed immediately. + promises.add(connection.pumpTo(*io).then([&io=*io](uint64_t size) { + io.shutdownWrite(); + })); + + promises.add(io->pumpTo(connection).then([&connection](uint64_t size) { + connection.shutdownWrite(); + })); + + auto pumpPromise = kj::joinPromisesFailFast(promises.finish()); + + return request.status.then( + [&response,&connection,fulfiller=kj::mv(paf.fulfiller), + pumpPromise=kj::mv(pumpPromise)] + (HttpClient::ConnectRequest::Status status) mutable -> kj::Promise { + if (status.statusCode >= 200 && status.statusCode < 300) { + // Release the read guard! + fulfiller->fulfill(kj::Maybe(nullptr)); + response.accept(status.statusCode, status.statusText, *status.headers); + return kj::mv(pumpPromise); + } else { + // If the connect request is rejected, we want to shutdown the tunnel + // and pipeline the status.errorBody to the AsyncOutputStream returned by + // reject if it exists. + pumpPromise = nullptr; + connection.shutdownWrite(); + fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + KJ_IF_MAYBE(errorBody, status.errorBody) { + auto out = response.reject(status.statusCode, status.statusText, *status.headers, + errorBody->get()->tryGetLength()); + return (*errorBody)->pumpTo(*out).then([](uint64_t) -> kj::Promise { + return kj::READY_NOW; + }).attach(kj::mv(out), kj::mv(*errorBody)); + } else { + response.reject(status.statusCode, status.statusText, *status.headers, (uint64_t)0); + return kj::READY_NOW; + } + } + }).attach(kj::mv(io)); } private: @@ -4746,20 +7028,28 @@ kj::Promise HttpService::Response::sendError( return sendError(statusCode, statusText, HttpHeaders(headerTable)); } -kj::Promise> HttpService::connect(kj::StringPtr host) { +kj::Promise HttpService::connect( + kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) { KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService"); } class HttpServer::Connection final: private HttpService::Response, + private HttpService::ConnectResponse, private HttpServerErrorHandler { public: Connection(HttpServer& server, kj::AsyncIoStream& stream, - HttpService& service) + SuspendableHttpServiceFactory factory, kj::Maybe suspendedRequest, + bool wantCleanDrain) : server(server), stream(stream), - service(service), - httpInput(stream, server.requestHeaderTable), - httpOutput(stream) { + factory(kj::mv(factory)), + httpInput(makeHttpInput(stream, server.requestHeaderTable, kj::mv(suspendedRequest))), + httpOutput(stream), + wantCleanDrain(wantCleanDrain) { ++server.connectionCount; } ~Connection() noexcept(false) { @@ -4783,23 +7073,69 @@ public: return kj::mv(promise); } + KJ_IF_MAYBE(p, tunnelRejected) { + // reject() was called to reject a CONNECT request. Finish sending and close the connection. + // Don't log the exception because it's probably a side-effect of this. + auto promise = kj::mv(*p); + tunnelRejected = nullptr; + return kj::mv(promise); + } + return sendError(kj::mv(e)); }); } + SuspendedRequest suspend(SuspendableRequest& suspendable) { + KJ_REQUIRE(httpInput.canSuspend(), + "suspend() may only be called before the request body is consumed"); + KJ_DEFER(suspended = true); + auto released = httpInput.releaseBuffer(); + return { + kj::mv(released.buffer), + released.leftover, + suspendable.method, + suspendable.url, + suspendable.headers.cloneShallow(), + }; + } + private: HttpServer& server; kj::AsyncIoStream& stream; - HttpService& service; + + SuspendableHttpServiceFactory factory; + // Creates a new kj::Own for each request we handle on this connection. + HttpInputStreamImpl httpInput; HttpOutputStream httpOutput; - kj::Maybe currentMethod; + kj::Maybe> currentMethod; bool timedOut = false; bool closed = false; bool upgraded = false; - bool webSocketClosed = false; + bool webSocketOrConnectClosed = false; bool closeAfterSend = false; // True if send() should set Connection: close. + bool wantCleanDrain = false; + bool suspended = false; kj::Maybe> webSocketError; + kj::Maybe> tunnelRejected; + kj::Maybe>> tunnelWriteGuard; + + static HttpInputStreamImpl makeHttpInput( + kj::AsyncIoStream& stream, + const kj::HttpHeaderTable& table, + kj::Maybe suspendedRequest) { + // Constructor helper function to create our HttpInputStreamImpl. + + KJ_IF_MAYBE(sr, suspendedRequest) { + return HttpInputStreamImpl(stream, + sr->buffer.releaseAsChars(), + sr->leftover.asChars(), + sr->method, + sr->url, + kj::mv(sr->headers)); + } + return HttpInputStreamImpl(stream, table); + } kj::Promise loop(bool firstRequest) { if (!firstRequest && server.draining && httpInput.isCleanDrain()) { @@ -4818,7 +7154,32 @@ private: if (httpInput.isCleanDrain()) { // If we haven't buffered any data, then we can safely drain here, so allow the wait to // be canceled by the onDrain promise. - timeoutPromise = timeoutPromise.exclusiveJoin(server.onDrain.addBranch()); + auto cleanDrainPromise = server.onDrain.addBranch() + .then([this]() -> kj::Promise { + // This is a little tricky... drain() has been called, BUT we could have read some data + // into the buffer in the meantime, and we don't want to lose that. If any data has + // arrived, then we have no choice but to read the rest of the request and respond to + // it. + if (!httpInput.isCleanDrain()) { + return kj::NEVER_DONE; + } + + // OK... As far as we know, no data has arrived in the buffer. However, unfortunately, + // we don't *really* know that, because read() is asynchronous. It may have already + // delivered some bytes, but we just haven't received the notification yet, because it's + // still queued on the event loop. As a horrible hack, we use evalLast(), so that any + // such pending notifications get a chance to be delivered. + // TODO(someday): Does this actually work on Windows, where the notification could also + // be queued on the IOCP? + return kj::evalLast([this]() -> kj::Promise { + if (httpInput.isCleanDrain()) { + return kj::READY_NOW; + } else { + return kj::NEVER_DONE; + } + }); + }); + timeoutPromise = timeoutPromise.exclusiveJoin(kj::mv(cleanDrainPromise)); } firstByte = firstByte.exclusiveJoin(timeoutPromise.then([this]() -> bool { @@ -4829,7 +7190,7 @@ private: auto receivedHeaders = firstByte .then([this,firstRequest](bool hasData) - -> kj::Promise { + -> kj::Promise { if (hasData) { auto readHeaders = httpInput.readRequestHeaders(); if (!firstRequest) { @@ -4837,7 +7198,7 @@ private: // the first byte of a pipeline response. readHeaders = readHeaders.exclusiveJoin( server.timer.afterDelay(server.settings.headerTimeout) - .then([this]() -> HttpHeaders::RequestOrProtocolError { + .then([this]() -> HttpHeaders::RequestConnectOrProtocolError { timedOut = true; return HttpHeaders::ProtocolError { 408, "Request Timeout", @@ -4850,7 +7211,7 @@ private: // Client closed connection or pipeline timed out with no bytes received. This is not an // error, so don't report one. this->closed = true; - return HttpHeaders::RequestOrProtocolError(HttpHeaders::ProtocolError { + return HttpHeaders::RequestConnectOrProtocolError(HttpHeaders::ProtocolError { 408, "Request Timeout", "Client closed connection or connection timeout " "while waiting for request headers.", nullptr @@ -4860,9 +7221,11 @@ private: if (firstRequest) { // On the first request, the header timeout starts ticking immediately upon request opening. + // NOTE: Since we assume that the client wouldn't have formed a connection if they did not + // intend to send a request, we immediately treat this connection as having an active + // request, i.e. we do NOT cancel it if drain() is called. auto timeoutPromise = server.timer.afterDelay(server.settings.headerTimeout) - .exclusiveJoin(server.onDrain.addBranch()) - .then([this]() -> HttpHeaders::RequestOrProtocolError { + .then([this]() -> HttpHeaders::RequestConnectOrProtocolError { timedOut = true; return HttpHeaders::ProtocolError { 408, "Request Timeout", @@ -4873,7 +7236,7 @@ private: } return receivedHeaders - .then([this](HttpHeaders::RequestOrProtocolError&& requestOrProtocolError) + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) -> kj::Promise { if (timedOut) { // Client took too long to send anything, so we're going to close the connection. In @@ -4901,20 +7264,95 @@ private: } KJ_SWITCH_ONEOF(requestOrProtocolError) { + KJ_CASE_ONEOF(request, HttpHeaders::ConnectRequest) { + auto& headers = httpInput.getHeaders(); + + currentMethod = HttpConnectMethod(); + + // The HTTP specification says that CONNECT requests have no meaningful payload + // but stops short of saying that CONNECT *cannot* have a payload. Implementations + // can choose to either accept payloads or reject them. We choose to reject it. + // Specifically, if there are Content-Length or Transfer-Encoding headers in the + // request headers, we'll automatically reject the CONNECT request. + // + // The key implication here is that any data that immediately follows the headers + // block of the CONNECT request is considered to be part of the tunnel if it is + // established. + + KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { + return sendError(HttpHeaders::ProtocolError { + 400, + "Bad Request"_kj, + "Bad Request"_kj, + nullptr, + }); + } + KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) { + return sendError(HttpHeaders::ProtocolError { + 400, + "Bad Request"_kj, + "Bad Request"_kj, + nullptr, + }); + } + + SuspendableRequest suspendable(*this, HttpConnectMethod(), request.authority, headers); + auto maybeService = factory(suspendable); + + if (suspended) { + return false; + } + + auto service = KJ_ASSERT_NONNULL(kj::mv(maybeService), + "SuspendableHttpServiceFactory did not suspend, but returned nullptr."); + auto connectStream = getConnectStream(); + auto promise = service->connect( + request.authority, headers, *connectStream, *this, {}) + .attach(kj::mv(service), kj::mv(connectStream)); + return promise.then([this]() mutable -> kj::Promise { + KJ_IF_MAYBE(p, tunnelRejected) { + // reject() was called to reject a CONNECT attempt. + // Finish sending and close the connection. + auto promise = kj::mv(*p); + tunnelRejected = nullptr; + return kj::mv(promise); + } + + if (httpOutput.isBroken()) { + return false; + } + + return httpOutput.flush().then([]() mutable -> kj::Promise { + // There is really no reasonable path to reusing a CONNECT connection. + return false; + }); + }); + } KJ_CASE_ONEOF(request, HttpHeaders::Request) { auto& headers = httpInput.getHeaders(); currentMethod = request.method; - auto body = httpInput.getEntityBody( - HttpInputStreamImpl::REQUEST, request.method, 0, headers); + + SuspendableRequest suspendable(*this, request.method, request.url, headers); + auto maybeService = factory(suspendable); + + if (suspended) { + return false; + } + + auto service = KJ_ASSERT_NONNULL(kj::mv(maybeService), + "SuspendableHttpServiceFactory did not suspend, but returned nullptr."); // TODO(perf): If the client disconnects, should we cancel the response? Probably, to // prevent permanent deadlock. It's slightly weird in that arguably the client should // be able to shutdown the upstream but still wait on the downstream, but I believe many // other HTTP servers do similar things. - auto promise = service.request( - request.method, request.url, headers, *body, *this); + auto body = httpInput.getEntityBody( + HttpInputStreamImpl::REQUEST, request.method, 0, headers); + + auto promise = service->request( + request.method, request.url, headers, *body, *this).attach(kj::mv(service)); return promise.then([this, body = kj::mv(body)]() mutable -> kj::Promise { // Response done. Await next request. @@ -4927,7 +7365,7 @@ private: if (upgraded) { // We've upgraded to WebSocket, and by now we should have closed the WebSocket. - if (!webSocketClosed) { + if (!webSocketOrConnectClosed) { // This is gonna segfault later so abort now instead. KJ_LOG(FATAL, "Accepted WebSocket object must be destroyed before HttpService " "request handler completes."); @@ -4956,9 +7394,14 @@ private: if (httpInput.canReuse()) { // Things look clean. Go ahead and accept the next request. - // Note that we don't have to handle server.draining here because we'll take care of - // it the next time around the loop. - return loop(false); + if (closeAfterSend) { + // We sent Connection: close, so drop the connection now. + return false; + } else { + // Note that we don't have to handle server.draining here because we'll take care + // of it the next time around the loop. + return loop(false); + } } else { // Apparently, the application did not read the request body. Maybe this is a bug, // or maybe not: maybe the client tried to upload too much data and the application @@ -4971,13 +7414,25 @@ private: // within its rights to start a new request. If we close the socket now, we might // interrupt that new request. // + // Or maybe we did send `Connection: close`, as indicated by `closeAfterSend` being + // true. Even in that case, we should still try to read and ignore the request, + // otherwise when we close the connection the client may get a "connection reset" + // error before they get a chance to actually read the response body that we sent + // them. + // // There's no way we can get out of this perfectly cleanly. HTTP just isn't good // enough at connection management. The best we can do is give the client some grace // period and then abort the connection. auto dummy = kj::heap(); - auto lengthGrace = body->pumpTo(*dummy, server.settings.canceledUploadGraceBytes) - .then([this](size_t amount) { + auto lengthGrace = kj::evalNow([&]() { + return body->pumpTo(*dummy, server.settings.canceledUploadGraceBytes); + }).catch_([](kj::Exception&& e) -> uint64_t { + // Reading from the input failed in some way. This may actually be the whole + // reason we got here in the first place so don't propagate this error, just + // give up on discarding the input. + return 0; // This zero is ignored but `canReuse()` will return false below. + }).then([this](uint64_t amount) { if (httpInput.canReuse()) { // Success, we can continue. return true; @@ -4993,11 +7448,12 @@ private: return lengthGrace.exclusiveJoin(kj::mv(timeGrace)) .then([this](bool clean) -> kj::Promise { - if (clean) { + if (clean && !closeAfterSend) { // We recovered. Continue loop. return loop(false); } else { - // Client still not done. Return broken. + // Client still not done, or we sent Connection: close and so want to drop the + // connection anyway. Return broken. return false; } }); @@ -5030,37 +7486,46 @@ private: if (!closeAfterSend) { // Check if application wants us to close connections. - KJ_IF_MAYBE(c, server.settings.callbacks) { + // + // If the application used listenHttpClientDrain() to listen, then it expects that after a + // clean drain, the connection is still open and can receive more requests. Otherwise, after + // receiving drain(), we will close the connection, so we should send a `Connection: close` + // header. + if (server.draining && !wantCleanDrain) { + closeAfterSend = true; + } else KJ_IF_MAYBE(c, server.settings.callbacks) { + // The application has registered its own callback to decide whether to send + // `Connection: close`. if (c->shouldClose()) { closeAfterSend = true; } } } - // TODO(0.10): If `server.draining`, we should probably set `closeAfterSend` -- UNLESS the - // connection was created using listenHttpCleanDrain(), in which case the application may - // intend to continue using the connection. - if (closeAfterSend) { connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "close"; } + bool isHeadRequest = method.tryGet().map([](auto& m) { + return m == HttpMethod::HEAD; + }).orDefault(false); + if (statusCode == 204 || statusCode == 304) { // No entity-body. } else if (statusCode == 205) { // Status code 205 also has no body, but unlike 204 and 304, it must explicitly encode an - // empty body, e.g. using content-length: 0. I'm guessing this is one of those things, where - // some early clients expected an explicit body while others assumed an empty body, and so - // the standard had to choose the common denominator. + // empty body, e.g. using content-length: 0. I'm guessing this is one of those things, + // where some early clients expected an explicit body while others assumed an empty body, + // and so the standard had to choose the common denominator. // // Spec: https://tools.ietf.org/html/rfc7231#section-6.3.6 connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = "0"; } else KJ_IF_MAYBE(s, expectedBodySize) { - // HACK: We interpret a zero-length expected body length on responses to HEAD requests to mean - // "don't set a Content-Length header at all." This provides a way to omit a body header on - // HEAD responses with non-null-body status codes. This is a hack that *only* makes sense - // for HEAD responses. - if (method != HttpMethod::HEAD || *s > 0) { + // HACK: We interpret a zero-length expected body length on responses to HEAD requests to + // mean "don't set a Content-Length header at all." This provides a way to omit a body + // header on HEAD responses with non-null-body status codes. This is a hack that *only* + // makes sense for HEAD responses. + if (!isHeadRequest || *s > 0) { lengthStr = kj::str(*s); connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr; } @@ -5071,7 +7536,7 @@ private: // For HEAD requests, if the application specified a Content-Length or Transfer-Encoding // header, use that instead of whatever we decided above. kj::ArrayPtr connectionHeadersArray = connectionHeaders; - if (method == HttpMethod::HEAD) { + if (isHeadRequest) { if (headers.get(HttpHeaderId::CONTENT_LENGTH) != nullptr || headers.get(HttpHeaderId::TRANSFER_ENCODING) != nullptr) { connectionHeadersArray = connectionHeadersArray @@ -5083,7 +7548,7 @@ private: statusCode, statusText, connectionHeadersArray)); kj::Own bodyStream; - if (method == HttpMethod::HEAD) { + if (isHeadRequest) { // Ignore entity-body. httpOutput.finishBody(); return heap(); @@ -5104,13 +7569,9 @@ private: "can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket"); auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); - // Unlike send(), we neither need nor want to null out currentMethod. The error cases below - // depend on it being non-null to allow error responses to be sent, and the happy path expects - // it to be GET. - - if (method != HttpMethod::GET) { - return sendWebSocketError("WebSocket must be initiated with a GET request."); - } + KJ_REQUIRE(method.tryGet().map([](auto& m) { + return m == HttpMethod::GET; + }).orDefault(false), "WebSocket must be initiated with a GET request."); if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") { return sendWebSocketError("The requested WebSocket version is not supported."); @@ -5123,12 +7584,50 @@ private: return sendWebSocketError("Missing Sec-WebSocket-Key"); } + kj::Maybe acceptedParameters; + kj::String agreedParameters; + auto compressionMode = server.settings.webSocketCompressionMode; + if (compressionMode == HttpServerSettings::AUTOMATIC_COMPRESSION) { + // If AUTOMATIC_COMPRESSION is enabled, we ignore the `headers` passed by the application and + // strictly refer to the `requestHeaders` from the client. + KJ_IF_MAYBE(value, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Perform compression parameter negotiation. + KJ_IF_MAYBE(config, _::tryParseExtensionOffers(*value)) { + acceptedParameters = kj::mv(*config); + } + } + } else if (compressionMode == HttpServerSettings::MANUAL_COMPRESSION) { + // If MANUAL_COMPRESSION is enabled, we use the `headers` passed in by the application, and + // try to find a configuration that respects both the server's preferred configuration, + // as well as the client's requested configuration. + KJ_IF_MAYBE(value, headers.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // First, we get the manual configuration using `headers`. + KJ_IF_MAYBE(manualConfig, _::tryParseExtensionOffers(*value)) { + KJ_IF_MAYBE(requestOffers, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Next, we to find a configuration that both the client and server can accept. + acceptedParameters = _::tryParseAllExtensionOffers(*requestOffers, *manualConfig); + } + } + } + } + auto websocketAccept = generateWebSocketAccept(key); kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT]; connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept; connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket"; connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade"; + KJ_IF_MAYBE(parameters, acceptedParameters) { + agreedParameters = _::generateExtensionResponse(*parameters); + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_EXTENSIONS] = agreedParameters; + } + + // Since we're about to write headers, we should nullify `currentMethod`. This tells + // `sendError(kj::Exception)` (called from `HttpServer::Connection::startLoop()`) not to expose + // the `HttpService::Response&` reference to the HttpServer's error `handleApplicationError()` + // callback. This prevents the error handler from inadvertently trying to send another error on + // the connection. + currentMethod = nullptr; httpOutput.writeHeaders(headers.serializeResponse( 101, "Switching Protocols", connectionHeaders)); @@ -5136,12 +7635,13 @@ private: upgraded = true; // We need to give the WebSocket an Own, but we only have a reference. This is // safe because the application is expected to drop the WebSocket object before returning - // from the request handler. For some extra safety, we check that webSocketClosed has been - // set true when the handler returns. - auto deferNoteClosed = kj::defer([this]() { webSocketClosed = true; }); + // from the request handler. For some extra safety, we check that webSocketOrConnectClosed has + // been set true when the handler returns. + auto deferNoteClosed = kj::defer([this]() { webSocketOrConnectClosed = true; }); kj::Own ownStream(&stream, kj::NullDisposer::instance); return upgradeToWebSocket(ownStream.attach(kj::mv(deferNoteClosed)), - httpInput, httpOutput, nullptr); + httpInput, httpOutput, nullptr, kj::mv(acceptedParameters), + server.settings.webSocketErrorHandler); } kj::Promise sendError(HttpHeaders::ProtocolError protocolError) { @@ -5151,9 +7651,7 @@ private: // HttpService, meaning no response has been sent and we can provide a Response object. auto promise = server.settings.errorHandler.orDefault(*this).handleClientProtocolError( kj::mv(protocolError), *this); - - return promise.then([this]() { return httpOutput.flush(); }) - .then([]() { return false; }); // loop ends after flush + return finishSendingError(kj::mv(promise)); } kj::Promise sendError(kj::Exception&& exception) { @@ -5162,9 +7660,7 @@ private: // We only provide the Response object if we know we haven't already sent a response. auto promise = server.settings.errorHandler.orDefault(*this).handleApplicationError( kj::mv(exception), currentMethod.map([this](auto&&) -> Response& { return *this; })); - - return promise.then([this]() { return httpOutput.flush(); }) - .then([]() { return false; }); // loop ends after flush + return finishSendingError(kj::mv(promise)); } kj::Promise sendError() { @@ -5172,9 +7668,19 @@ private: // We can provide a Response object, since none has already been sent. auto promise = server.settings.errorHandler.orDefault(*this).handleNoResponse(*this); + return finishSendingError(kj::mv(promise)); + } - return promise.then([this]() { return httpOutput.flush(); }) - .then([]() { return false; }); // loop ends after flush + kj::Promise finishSendingError(kj::Promise promise) { + return promise.then([this]() -> kj::Promise { + if (httpOutput.isBroken()) { + // Skip flush for broken streams, since it will throw an exception that may be worse than + // the one we just handled. + return kj::READY_NOW; + } else { + return httpOutput.flush(); + } + }).then([]() { return false; }); // loop ends after flush } kj::Own sendWebSocketError(StringPtr errorMessage) { @@ -5221,6 +7727,57 @@ private: return kj::heap(KJ_EXCEPTION(FAILED, "received bad WebSocket handshake", errorMessage)); } + + kj::Own getConnectStream() { + // Returns an AsyncIoStream over the internal stream but that waits for a Promise to be + // resolved to allow writes after either accept or reject are called. Reads are allowed + // immediately. + KJ_REQUIRE(tunnelWriteGuard == nullptr, "the tunnel stream was already retrieved"); + auto paf = kj::newPromiseAndFulfiller(); + tunnelWriteGuard = kj::mv(paf.fulfiller); + + kj::Own ownStream(&stream, kj::NullDisposer::instance); + auto releasedBuffer = httpInput.releaseBuffer(); + auto deferNoteClosed = kj::defer([this]() { webSocketOrConnectClosed = true; }); + return kj::heap( + kj::heap( + kj::mv(ownStream), + kj::mv(releasedBuffer.buffer), + releasedBuffer.leftover).attach(kj::mv(deferNoteClosed)), + kj::Maybe(nullptr), + kj::mv(paf.promise)); + } + + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); + currentMethod = nullptr; + KJ_ASSERT(method.is(), "only use accept() with CONNECT requests"); + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + tunnelRejected = nullptr; + + auto& fulfiller = KJ_ASSERT_NONNULL(tunnelWriteGuard, "the tunnel stream was not initialized"); + httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText)); + auto promise = httpOutput.flush().then([&fulfiller]() { + fulfiller->fulfill(); + }).eagerlyEvaluate(nullptr); + fulfiller = fulfiller.attach(kj::mv(promise)); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); + KJ_REQUIRE(method.is(), "Only use reject() with CONNECT requests."); + KJ_REQUIRE(statusCode < 200 || statusCode >= 300, "the statusCode must not be 2xx for reject."); + tunnelRejected = Maybe>(true); + + auto& fulfiller = KJ_ASSERT_NONNULL(tunnelWriteGuard, "the tunnel stream was not initialized"); + fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "the tunnel request was rejected")); + closeAfterSend = true; + return send(statusCode, statusText, headers, expectedBodySize); + } }; HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, @@ -5262,18 +7819,13 @@ kj::Promise HttpServer::listenHttp(kj::ConnectionReceiver& port) { kj::Promise HttpServer::listenLoop(kj::ConnectionReceiver& port) { return port.accept() .then([this,&port](kj::Own&& connection) -> kj::Promise { - if (draining) { - // Can get here if we *just* started draining. - return kj::READY_NOW; - } - - tasks.add(listenHttp(kj::mv(connection))); + tasks.add(kj::evalNow([&]() { return listenHttp(kj::mv(connection)); })); return listenLoop(port); }); } kj::Promise HttpServer::listenHttp(kj::Own connection) { - auto promise = listenHttpCleanDrain(*connection).ignoreResult(); + auto promise = listenHttpImpl(*connection, false /* wantCleanDrain */).ignoreResult(); // eagerlyEvaluate() to maintain historical guarantee that this method eagerly closes the // connection when done. @@ -5281,19 +7833,54 @@ kj::Promise HttpServer::listenHttp(kj::Own connection) } kj::Promise HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection) { - kj::Own obj; + return listenHttpImpl(connection, true /* wantCleanDrain */); +} + +kj::Promise HttpServer::listenHttpImpl(kj::AsyncIoStream& connection, bool wantCleanDrain) { + kj::Own srv; KJ_SWITCH_ONEOF(service) { KJ_CASE_ONEOF(ptr, HttpService*) { - obj = heap(*this, connection, *ptr); + // Fake Own okay because we can assume the HttpService outlives this HttpServer, and we can + // assume `this` HttpServer outlives the returned `listenHttpCleanDrain()` promise, which will + // own the fake Own. + srv = kj::Own(ptr, kj::NullDisposer::instance); } KJ_CASE_ONEOF(func, HttpServiceFactory) { - auto srv = func(connection); - obj = heap(*this, connection, *srv); - obj = obj.attach(kj::mv(srv)); + srv = func(connection); } } + KJ_ASSERT(srv.get() != nullptr); + + return listenHttpImpl(connection, [srv = kj::mv(srv)](SuspendableRequest&) mutable { + // This factory function will be owned by the Connection object, meaning the Connection object + // will own the HttpService. We also know that the Connection object outlives all + // service.request() promises (service.request() is called from a Connection member function). + // The Owns we return from this function are attached to the service.request() promises, + // meaning this factory function will outlive all Owns we return. So, it's safe to return a fake + // Own. + return kj::Own(srv.get(), kj::NullDisposer::instance); + }, nullptr /* suspendedRequest */, wantCleanDrain); +} + +kj::Promise HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest) { + // Don't close on drain, because a "clean drain" means we return the connection to the + // application still-open between requests so that it can continue serving future HTTP requests + // on it. + return listenHttpImpl(connection, kj::mv(factory), kj::mv(suspendedRequest), + true /* wantCleanDrain */); +} + +kj::Promise HttpServer::listenHttpImpl(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest, + bool wantCleanDrain) { + auto obj = heap(*this, connection, kj::mv(factory), kj::mv(suspendedRequest), + wantCleanDrain); + // Start reading requests and responding to them, but immediately cancel processing if the client // disconnects. auto promise = obj->startLoop(true) @@ -5304,9 +7891,44 @@ kj::Promise HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection return promise.attach(kj::mv(obj)).eagerlyEvaluate(nullptr); } -void HttpServer::taskFailed(kj::Exception&& exception) { +namespace { +void defaultHandleListenLoopException(kj::Exception&& exception) { KJ_LOG(ERROR, "unhandled exception in HTTP server", exception); } +} // namespace + +void HttpServer::taskFailed(kj::Exception&& exception) { + KJ_IF_MAYBE(handler, settings.errorHandler) { + handler->handleListenLoopException(kj::mv(exception)); + } else { + defaultHandleListenLoopException(kj::mv(exception)); + } +} + +HttpServer::SuspendedRequest::SuspendedRequest( + kj::Array bufferParam, kj::ArrayPtr leftoverParam, + kj::OneOf method, + kj::StringPtr url, HttpHeaders headers) + : buffer(kj::mv(bufferParam)), + leftover(leftoverParam), + method(method), + url(url), + headers(kj::mv(headers)) { + if (leftover.size() > 0) { + // We have a `leftover`; make sure it is a slice of `buffer`. + KJ_ASSERT(leftover.begin() >= buffer.begin() && leftover.begin() <= buffer.end()); + KJ_ASSERT(leftover.end() >= buffer.begin() && leftover.end() <= buffer.end()); + } else { + // We have no `leftover`, but we still expect it to point into `buffer` somewhere. This is + // important so that `messageHeaderEnd` is initialized correctly in HttpInputStreamImpl's + // constructor. + KJ_ASSERT(leftover.begin() >= buffer.begin() && leftover.begin() <= buffer.end()); + } +} + +HttpServer::SuspendedRequest HttpServer::SuspendableRequest::suspend() { + return connection.suspend(*this); +} kj::Promise HttpServerErrorHandler::handleClientProtocolError( HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) { @@ -5340,6 +7962,8 @@ kj::Promise HttpServerErrorHandler::handleApplicationError( } KJ_IF_MAYBE(r, response) { + KJ_LOG(INFO, "threw exception while serving HTTP response", exception); + HttpHeaderTable headerTable {}; HttpHeaders headers(headerTable); headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); @@ -5370,6 +7994,10 @@ kj::Promise HttpServerErrorHandler::handleApplicationError( return kj::READY_NOW; } +void HttpServerErrorHandler::handleListenLoopException(kj::Exception&& exception) { + defaultHandleListenLoopException(kj::mv(exception)); +} + kj::Promise HttpServerErrorHandler::handleNoResponse(kj::HttpService::Response& response) { HttpHeaderTable headerTable {}; HttpHeaders headers(headerTable); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.h index c65a1bb162b..151222a562c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/http.h @@ -39,6 +39,9 @@ #include #include #include +#include + +KJ_BEGIN_HEADER namespace kj { @@ -90,8 +93,17 @@ KJ_HTTP_FOR_EACH_METHOD(DECLARE_METHOD) #undef DECLARE_METHOD }; +struct HttpConnectMethod {}; +// CONNECT is handled specially and separately from the other HttpMethods. + kj::StringPtr KJ_STRINGIFY(HttpMethod method); +kj::StringPtr KJ_STRINGIFY(HttpConnectMethod method); kj::Maybe tryParseHttpMethod(kj::StringPtr name); +kj::Maybe> tryParseHttpMethodAllowingConnect( + kj::StringPtr name); +// Like tryParseHttpMethod but, as the name suggests, explicitly allows for the CONNECT +// method. Added as a separate function instead of modifying tryParseHttpMethod to avoid +// breaking API changes in existing uses of tryParseHttpMethod. class HttpHeaderTable; @@ -219,7 +231,7 @@ class HttpHeaderTable { kj::Own table; }; - KJ_DISALLOW_COPY(HttpHeaderTable); // Can't copy because HttpHeaderId points to the table. + KJ_DISALLOW_COPY_AND_MOVE(HttpHeaderTable); // Can't copy because HttpHeaderId points to the table. ~HttpHeaderTable() noexcept(false); uint idCount() const; @@ -234,9 +246,20 @@ class HttpHeaderTable { kj::StringPtr idToString(HttpHeaderId id) const; // Get the canonical string name for the given ID. + bool isReady() const; + // Returns true if this HttpHeaderTable either was default constructed or its Builder has + // invoked `build()` and released it. + private: kj::Vector namesById; kj::Own idsByName; + + enum class BuildStatus { + UNSTARTED = 0, + BUILDING = 1, + FINISHED = 2, + }; + BuildStatus buildStatus = BuildStatus::UNSTARTED; }; class HttpHeaders { @@ -287,6 +310,12 @@ class HttpHeaders { kj::Maybe get(HttpHeaderId id) const; // Read a header. + // + // Note that there is intentionally no method to look up a header by string name rather than + // header ID. The intent is that you should always allocate a header ID for any header that you + // care about, so that you can get() it by ID. Headers with registered IDs are stored in an array + // indexed by ID, making lookup fast. Headers without registered IDs are stored in a separate list + // that is optimized for re-transmission of the whole list, but not for lookup. template void forEach(Func&& func) const; @@ -331,13 +360,16 @@ class HttpHeaders { void takeOwnership(kj::String&& string); void takeOwnership(kj::Array&& chars); void takeOwnership(HttpHeaders&& otherHeaders); - // Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful + // Takes ownership of a string so that it lives until the HttpHeaders object is destroyed. Useful // when you've passed a dynamic value to set() or add() or parse*(). struct Request { HttpMethod method; kj::StringPtr url; }; + struct ConnectRequest { + kj::StringPtr authority; + }; struct Response { uint statusCode; kj::StringPtr statusText; @@ -376,12 +408,15 @@ class HttpHeaders { using RequestOrProtocolError = kj::OneOf; using ResponseOrProtocolError = kj::OneOf; + using RequestConnectOrProtocolError = kj::OneOf; RequestOrProtocolError tryParseRequest(kj::ArrayPtr content); + RequestConnectOrProtocolError tryParseRequestOrConnect(kj::ArrayPtr content); ResponseOrProtocolError tryParseResponse(kj::ArrayPtr content); + // Parse an HTTP header blob and add all the headers to this object. // - // `content` should be all text from the start of the request to the first occurrance of two + // `content` should be all text from the start of the request to the first occurrence of two // newlines in a row -- including the first of these two newlines, but excluding the second. // // The parse is performed with zero copies: The callee clobbers `content` with '\0' characters @@ -393,6 +428,8 @@ class HttpHeaders { kj::String serializeRequest(HttpMethod method, kj::StringPtr url, kj::ArrayPtr connectionHeaders = nullptr) const; + kj::String serializeConnectRequest(kj::StringPtr authority, + kj::ArrayPtr connectionHeaders = nullptr) const; kj::String serializeResponse(uint statusCode, kj::StringPtr statusText, kj::ArrayPtr connectionHeaders = nullptr) const; // **Most applications will not use these methods; they are called by the HTTP client and server @@ -476,6 +513,17 @@ class HttpInputStream { // The returned struct contains pointers directly into a buffer that is invalidated on the next // message read. + struct Connect { + kj::StringPtr authority; + const HttpHeaders& headers; + kj::Own body; + }; + virtual kj::Promise> readRequestAllowingConnect() = 0; + // Reads one HTTP request from the input stream. + // + // The returned struct contains pointers directly into a buffer that is invalidated on the next + // message read. + struct Response { uint statusCode; kj::StringPtr statusText; @@ -517,6 +565,16 @@ class EntropySource { virtual void generate(kj::ArrayPtr buffer) = 0; }; +struct CompressionParameters { + // These are the parameters for `Sec-WebSocket-Extensions` permessage-deflate extension. + // Since we cannot distinguish the client/server in `upgradeToWebSocket`, we use the prefixes + // `inbound` and `outbound` instead. + bool outboundNoContextTakeover = false; + bool inboundNoContextTakeover = false; + kj::Maybe outboundMaxWindowBits = nullptr; + kj::Maybe inboundMaxWindowBits = nullptr; +}; + class WebSocket { // Interface representincg an open WebSocket session. // @@ -556,6 +614,19 @@ class WebSocket { // resolves, but send() or receive() will throw DISCONNECTED when appropriate. See also // kj::AsyncOutputStream::whenWriteDisconnected().) + struct ProtocolError { + // Represents a protocol error, such as a bad opcode or oversize message. + + uint statusCode; + // Suggested WebSocket status code that should be used when returning an error to the client. + // + // Most errors are 1002; an oversize message will be 1009. + + kj::StringPtr description; + // An error description safe for all the world to see. This should be at most 123 bytes so that + // it can be used as the body of a Close frame (RFC 6455 sections 5.5 and 5.5.1). + }; + struct Close { uint16_t code; kj::String reason; @@ -586,6 +657,111 @@ class WebSocket { virtual uint64_t sentByteCount() = 0; virtual uint64_t receivedByteCount() = 0; + + enum ExtensionsContext { + // Indicate whether a Sec-WebSocket-Extension header should be rendered for use in request + // headers or response headers. + REQUEST, + RESPONSE + }; + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return nullptr; } + // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using + // certain extensions (e.g. compression settings), then this method returns what those extensions + // are. For example, matching extensions between standard WebSockets allows pumping to be + // implemented by pumping raw bytes between network connections, without reading individual frames. + // + // A null return value indicates that there is no preference. A non-null return value containing + // an empty string indicates a preference for no extensions to be applied. +}; + +using TlsStarterCallback = kj::Maybe(kj::StringPtr)>>; +struct HttpConnectSettings { + bool useTls = false; + // Requests to automatically establish a TLS session over the connection. The remote party + // will be expected to present a valid certificate matching the requested hostname. + kj::Maybe tlsStarter; + // This is an output parameter. It doesn't need to be set. But if it is set, then it may get + // filled with a callback function. It will get filled with `nullptr` if any of the following + // are true: + // + // * kj is not built with TLS support + // * the underlying HttpClient does not support the startTls mechanism + // * `useTls` has been set to `true` and so TLS has already been started + // + // The callback function itself can be called to initiate a TLS handshake on the connection in + // between write() operations. It is not allowed to initiate a TLS handshake while a write + // operation or a pump operation to the connection exists. Read operations are not subject to + // the same constraint, however: implementations are required to be able to handle TLS + // initiation while a read operation or pump operation from the connection exists. Once the + // promise returned from the callback is fulfilled, the connection has become a secure stream, + // and write operations are once again permitted. The StringPtr parameter to the callback, + // expectedServerHostname may be dropped after the function synchronously returns. + // + // The PausableReadAsyncIoStream class defined below can be used to ensure that read operations + // are not pending when the tlsStarter is invoked. + // + // This mechanism is required for certain protocols, more info can be found on + // https://en.wikipedia.org/wiki/Opportunistic_TLS. +}; + + +class PausableReadAsyncIoStream final: public kj::AsyncIoStream { + // A custom AsyncIoStream which can pause pending reads. This is used by startTls to pause a + // a read before TLS is initiated. + // + // TODO(cleanup): this class should be rewritten to use a CRTP mixin approach so that pumps + // can be optimised once startTls is invoked. + class PausableRead; +public: + PausableReadAsyncIoStream(kj::Own stream) + : inner(kj::mv(stream)), currentlyWriting(false), currentlyReading(false) {} + + _::Deferred> trackRead(); + + _::Deferred> trackWrite(); + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + + kj::Promise tryReadImpl(void* buffer, size_t minBytes, size_t maxBytes); + + kj::Maybe tryGetLength() override; + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override; + + kj::Promise write(const void* buffer, size_t size) override; + + kj::Promise write(kj::ArrayPtr> pieces) override; + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override; + + kj::Promise whenWriteDisconnected() override; + + void shutdownWrite() override; + + void abortRead() override; + + kj::Maybe getFd() const override; + + void pause(); + + void unpause(); + + bool getCurrentlyReading(); + + bool getCurrentlyWriting(); + + kj::Own takeStream(); + + void replaceStream(kj::Own stream); + + void reject(kj::Exception&& exc); + +private: + kj::Own inner; + kj::Maybe maybePausableRead; + bool currentlyWriting; + bool currentlyReading; }; class HttpClient { @@ -614,7 +790,7 @@ class HttpClient { // Content-Length: 0. kj::Promise response; - // Promise for the eventual respnose. + // Promise for the eventual response. }; virtual Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, @@ -645,9 +821,45 @@ class HttpClient { // `url` and `headers` need only remain valid until `openWebSocket()` returns (they can be // stack-allocated). - virtual kj::Promise> connect(kj::StringPtr host); - // Handles CONNECT requests. Only relevant for proxy clients. Default implementation throws - // UNIMPLEMENTED. + struct ConnectRequest { + struct Status { + uint statusCode; + kj::String statusText; + kj::Own headers; + kj::Maybe> errorBody; + // If the connect request is rejected, the statusCode can be any HTTP status code + // outside the 200-299 range and errorBody *may* be specified if there is a rejection + // payload. + + // TODO(perf): Having Status own the statusText and headers is a bit unfortunate. + // Ideally we could have these be non-owned so that the headers object could just + // point directly into HttpOutputStream's buffer and not be copied. That's a bit + // more difficult to with CONNECT since the lifetimes of the buffers are a little + // different than with regular HTTP requests. It should still be possible but for + // now copying and owning the status text and headers is easier. + + Status(uint statusCode, + kj::String statusText, + kj::Own headers, + kj::Maybe> errorBody = nullptr) + : statusCode(statusCode), + statusText(kj::mv(statusText)), + headers(kj::mv(headers)), + errorBody(kj::mv(errorBody)) {} + }; + + kj::Promise status; + kj::Own connection; + }; + + virtual ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings); + // Handles CONNECT requests. + // + // `host` must specify both the host and port (e.g. "example.org:1234"). + // + // The `host` and `headers` need only remain valid until `connect()` returns (it can be + // stack-allocated). }; class HttpService { @@ -675,9 +887,22 @@ class HttpService { // // `statusText` and `headers` need only remain valid until send() returns (they can be // stack-allocated). + // + // `send()` may only be called a single time. Calling it a second time will cause an exception + // to be thrown. virtual kj::Own acceptWebSocket(const HttpHeaders& headers) = 0; // If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send(). + // + // If the request is an invalid WebSocket request (e.g., it has an Upgrade: websocket header, + // but other WebSocket-related headers are invalid), `acceptWebSocket()` will throw an + // exception, and the HttpServer will return a 400 Bad Request response and close the + // connection. In this circumstance, the HttpServer will ignore any exceptions which propagate + // from the `HttpService::request()` promise. `HttpServerErrorHandler::handleApplicationError()` + // will not be invoked, and the HttpServer's listen task will be fulfilled normally. + // + // `acceptWebSocket()` may only be called a single time. Calling it a second time will cause an + // exception to be thrown. kj::Promise sendError(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers); @@ -704,10 +929,49 @@ class HttpService { // // Request processing can be canceled by dropping the returned promise. HttpServer may do so if // the client disconnects prematurely. + // + // The implementation of `request()` should usually not try to use `response` in any way in + // exception-handling code, because it is often not possible to tell whether `Response::send()` or + // `Response::acceptWebSocket()` has already been called. Instead, to generate error HTTP + // responses for the client, implement an HttpServerErrorHandler and pass it to the HttpServer via + // HttpServerSettings. If the `HttpService::request()` promise rejects and no response has yet + // been sent, `HttpServerErrorHandler::handleApplicationError()` will be passed a non-null + // `Maybe` parameter. + + class ConnectResponse { + public: + virtual void accept( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers) = 0; + // Signals acceptance of the CONNECT tunnel. + + virtual kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) = 0; + // Signals rejection of the CONNECT tunnel. + }; - virtual kj::Promise> connect(kj::StringPtr host); - // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws - // UNIMPLEMENTED. + virtual kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + HttpConnectSettings settings); + // Handles CONNECT requests. + // + // The `host` must include host and port. + // + // `host` and `headers` are invalidated when accept or reject is called on the ConnectResponse + // or when the returned promise resolves, whichever comes first. + // + // The connection is provided to support pipelining. Writes to the connection will be blocked + // until one of either accept() or reject() is called on tunnel. Reads from the connection are + // permitted at any time. + // + // Request processing can be canceled by dropping the returned promise. HttpServer may do so if + // the client disconnects prematurely. }; class HttpClientErrorHandler { @@ -749,6 +1013,29 @@ struct HttpClientSettings { kj::Maybe errorHandler = nullptr; // Customize how protocol errors are handled by the HttpClient. If null, HttpClientErrorHandler's // default implementation will be used. + + enum WebSocketCompressionMode { + NO_COMPRESSION, + MANUAL_COMPRESSION, // Lets the application decide the compression configuration (if any). + AUTOMATIC_COMPRESSION, // Automatically includes the compression header in the WebSocket request. + }; + WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; + + kj::Maybe tlsContext; + // A reference to a TLS context that will be used when tlsStarter is invoked. +}; + +class WebSocketErrorHandler { +public: + virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); + // Handles low-level protocol errors in received WebSocket data. + // + // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket + // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown + // opcode, or similar. + // + // You would override this method in order to customize the exception. You cannot prevent the + // exception from being thrown. }; kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, @@ -817,7 +1104,9 @@ kj::Own newHttpInputStream( // continue reading from `input` in a reliable way. kj::Own newWebSocket(kj::Own stream, - kj::Maybe maskEntropySource); + kj::Maybe maskEntropySource, + kj::Maybe compressionConfig = nullptr, + kj::Maybe errorHandler = nullptr); // Create a new WebSocket on top of the given stream. It is assumed that the HTTP -> WebSocket // upgrade handshake has already occurred (or is not needed), and messages can immediately be // sent and received on the stream. Normally applications would not call this directly. @@ -829,6 +1118,13 @@ kj::Own newWebSocket(kj::Own stream, // purpose of the mask is to prevent badly-written HTTP proxies from interpreting "things that look // like HTTP requests" in a message as being actual HTTP requests, which could result in cache // poisoning. See RFC6455 section 10.3. +// +// `compressionConfig` is an optional argument that allows us to specify how the WebSocket should +// compress and decompress messages. The configuration is determined by the +// `Sec-WebSocket-Extensions` header during WebSocket negotiation. +// +// `errorHandler` is an optional argument that lets callers throw custom exceptions for WebSocket +// protocol errors. struct WebSocketPipe { kj::Own ends[2]; @@ -865,6 +1161,16 @@ struct HttpServerSettings { kj::Maybe callbacks = nullptr; // Additional optional callbacks used to control some server behavior. + + kj::Maybe webSocketErrorHandler = nullptr; + // Customize exceptions thrown on WebSocket protocol errors. + + enum WebSocketCompressionMode { + NO_COMPRESSION, + MANUAL_COMPRESSION, // Gives the application more control when considering whether to compress. + AUTOMATIC_COMPRESSION, // Will perform compression parameter negotiation if client requests it. + }; + WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; }; class HttpServerErrorHandler { @@ -899,6 +1205,12 @@ class HttpServerErrorHandler { // // Also unlike `HttpService::request()`, it is okay to return kj::READY_NOW without calling // `response.send()`. In this case, no response will be sent, and the connection will be closed. + + virtual void handleListenLoopException(kj::Exception&& exception); + // Override this function to customize error handling for individual connections in the + // `listenHttp()` overload which accepts a ConnectionReceiver reference. + // + // The default handler uses KJ_LOG() to log the exception as an error. }; class HttpServerCallbacks { @@ -918,6 +1230,9 @@ class HttpServer final: private kj::TaskSet::ErrorHandler { public: typedef HttpServerSettings Settings; typedef kj::Function(kj::AsyncIoStream&)> HttpServiceFactory; + class SuspendableRequest; + typedef kj::Function>(SuspendableRequest&)> + SuspendableHttpServiceFactory; HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, HttpService& service, Settings settings = Settings()); @@ -947,7 +1262,7 @@ class HttpServer final: private kj::TaskSet::ErrorHandler { // Reads HTTP requests from the given connection and directs them to the handler. A successful // completion of the promise indicates that all requests received on the connection resulted in // a complete response, and the client closed the connection gracefully or drain() was called. - // The promise throws if an unparseable request is received or if some I/O error occurs. Dropping + // The promise throws if an unparsable request is received or if some I/O error occurs. Dropping // the returned promise will cancel all I/O on the connection and cancel any in-flight requests. kj::Promise listenHttpCleanDrain(kj::AsyncIoStream& connection); @@ -958,6 +1273,56 @@ class HttpServer final: private kj::TaskSet::ErrorHandler { // caller should close it without any further reads/writes. Note this only ever returns `true` // if you called `drain()` -- otherwise this server would keep handling the connection. + class SuspendedRequest { + // SuspendedRequest is a representation of a request immediately after parsing the method line and + // headers. You can obtain one of these by suspending a request by calling + // SuspendableRequest::suspend(), then later resume the request with another call to + // listenHttpCleanDrain(). + + public: + // Nothing, this is an opaque type. + + private: + SuspendedRequest(kj::Array, kj::ArrayPtr, kj::OneOf, kj::StringPtr, HttpHeaders); + + kj::Array buffer; + // A buffer containing at least the request's method, URL, and headers, and possibly content + // thereafter. + + kj::ArrayPtr leftover; + // Pointer to the end of the request headers. If this has a non-zero length, then our buffer + // contains additional content, presumably the head of the request body. + + kj::OneOf method; + kj::StringPtr url; + HttpHeaders headers; + // Parsed request front matter. `url` and `headers` both store pointers into `buffer`. + + friend class HttpServer; + }; + + kj::Promise listenHttpCleanDrain(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest = nullptr); + // Like listenHttpCleanDrain(), but allows you to suspend requests. + // + // When this overload is in use, the HttpServer's default HttpService or HttpServiceFactory is not + // used. Instead, the HttpServer reads the request method line and headers, then calls `factory` + // with a SuspendableRequest representing the request parsed so far. The factory may then return + // a kj::Own for that specific request, or it may call SuspendableRequest::suspend() + // and return nullptr. (It is an error for the factory to return nullptr without also calling + // suspend(); this will result in a rejected listenHttpCleanDrain() promise.) + // + // If the factory chooses to suspend, the listenHttpCleanDrain() promise is resolved with false + // at the earliest opportunity. + // + // SuspendableRequest::suspend() returns a SuspendedRequest. You can resume this request later by + // calling this same listenHttpCleanDrain() overload with the original connection stream, and the + // SuspendedRequest in question. + // + // This overload of listenHttpCleanDrain() implements draining, as documented above. Note that the + // returned promise will resolve to false (not clean) if a request is suspended. + private: class Connection; @@ -982,6 +1347,38 @@ class HttpServer final: private kj::TaskSet::ErrorHandler { kj::Promise listenLoop(kj::ConnectionReceiver& port); void taskFailed(kj::Exception&& exception) override; + + kj::Promise listenHttpImpl(kj::AsyncIoStream& connection, bool wantCleanDrain); + kj::Promise listenHttpImpl(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest, + bool wantCleanDrain); +}; + +class HttpServer::SuspendableRequest { + // Interface passed to the SuspendableHttpServiceFactory parameter of listenHttpCleanDrain(). + +public: + kj::OneOf method; + kj::StringPtr url; + const HttpHeaders& headers; + // Parsed request front matter, so the implementer can decide whether to suspend the request. + + SuspendedRequest suspend(); + // Signal to the HttpServer that the current request loop should be exited. Return a + // SuspendedRequest, containing HTTP method, URL, and headers access, along with the actual header + // buffer. The request can be later resumed with a call to listenHttpCleanDrain() using the same + // connection. + +private: + explicit SuspendableRequest( + Connection& connection, kj::OneOf method, kj::StringPtr url, const HttpHeaders& headers) + : method(method), url(url), headers(headers), connection(connection) {} + KJ_DISALLOW_COPY_AND_MOVE(SuspendableRequest); + + Connection& connection; + + friend class Connection; }; // ======================================================================================= @@ -992,10 +1389,22 @@ inline void HttpHeaderId::requireFrom(const HttpHeaderTable& table) const { "the provided HttpHeaderId is from the wrong HttpHeaderTable"); } -inline kj::Own HttpHeaderTable::Builder::build() { return kj::mv(table); } +inline kj::Own HttpHeaderTable::Builder::build() { + table->buildStatus = BuildStatus::FINISHED; + return kj::mv(table); +} inline HttpHeaderTable& HttpHeaderTable::Builder::getFutureTable() { return *table; } inline uint HttpHeaderTable::idCount() const { return namesById.size(); } +inline bool HttpHeaderTable::isReady() const { + switch (buildStatus) { + case BuildStatus::UNSTARTED: return true; + case BuildStatus::BUILDING: return false; + case BuildStatus::FINISHED: return true; + } + + KJ_UNREACHABLE; +} inline kj::StringPtr HttpHeaderTable::idToString(HttpHeaderId id) const { id.requireFrom(*this); @@ -1039,4 +1448,57 @@ inline void HttpHeaders::forEach(Func1&& func1, Func2&& func2) const { } } +// ======================================================================================= +namespace _ { // private implementation details for WebSocket compression + +kj::ArrayPtr splitNext(kj::ArrayPtr& cursor, char delimiter); + +void stripLeadingAndTrailingSpace(ArrayPtr& str); + +kj::Vector> splitParts(kj::ArrayPtr input, char delim); + +struct KeyMaybeVal { + ArrayPtr key; + kj::Maybe> val; +}; + +kj::Array toKeysAndVals(const kj::ArrayPtr>& params); + +struct UnverifiedConfig { + // An intermediate representation of the final `CompressionParameters` struct; used during parsing. + // We use it to ensure the structure of an offer is generally correct, see + // `populateUnverifiedConfig()` for details. + bool clientNoContextTakeover = false; + bool serverNoContextTakeover = false; + kj::Maybe> clientMaxWindowBits = nullptr; + kj::Maybe> serverMaxWindowBits = nullptr; +}; + +kj::Maybe populateUnverifiedConfig(kj::Array& params); + +kj::Maybe validateCompressionConfig(UnverifiedConfig&& config, + bool isAgreement); + +kj::Vector findValidExtensionOffers(StringPtr offers); + +kj::String generateExtensionRequest(const ArrayPtr& extensions); + +kj::Maybe tryParseExtensionOffers(StringPtr offers); + +kj::Maybe tryParseAllExtensionOffers(StringPtr offers, + CompressionParameters manualConfig); + +kj::Maybe compareClientAndServerConfigs(CompressionParameters requestConfig, + CompressionParameters manualConfig); + +kj::String generateExtensionResponse(const CompressionParameters& parameters); + +kj::OneOf tryParseExtensionAgreement( + const Maybe& clientOffer, + StringPtr agreedParameters); + +}; // namespace _ (private) + } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/readiness-io.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/readiness-io.h index d55d8474b01..2ed94684167 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/readiness-io.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/readiness-io.h @@ -23,6 +23,8 @@ #include +KJ_BEGIN_HEADER + namespace kj { class ReadyInputStreamWrapper { @@ -34,7 +36,7 @@ class ReadyInputStreamWrapper { public: ReadyInputStreamWrapper(AsyncInputStream& input); ~ReadyInputStreamWrapper() noexcept(false); - KJ_DISALLOW_COPY(ReadyInputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(ReadyInputStreamWrapper); kj::Maybe read(kj::ArrayPtr dst); // Reads bytes into `dst`, returning the number of bytes read. Returns zero only at EOF. Returns @@ -43,6 +45,9 @@ class ReadyInputStreamWrapper { kj::Promise whenReady(); // Returns a promise that resolves when read() will return non-null. + bool isAtEnd() { return eof; } + // Returns true if read() would return zero. + private: AsyncInputStream& input; kj::ForkedPromise pumpTask = nullptr; @@ -62,11 +67,11 @@ class ReadyOutputStreamWrapper { public: ReadyOutputStreamWrapper(AsyncOutputStream& output); ~ReadyOutputStreamWrapper() noexcept(false); - KJ_DISALLOW_COPY(ReadyOutputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(ReadyOutputStreamWrapper); kj::Maybe write(kj::ArrayPtr src); - // Writes bytes from `src`, returning the number of bytes written. Never returns zero. Returns - // nullptr if not ready. + // Writes bytes from `src`, returning the number of bytes written. Never returns zero for + // a non-empty `src`. Returns nullptr if not ready. kj::Promise whenReady(); // Returns a promise that resolves when write() will return non-null. @@ -122,3 +127,5 @@ class ReadyOutputStreamWrapper::Cork { }; } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls-test.c++ index 7c5cf03800d..dddefa5747f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls-test.c++ @@ -27,6 +27,8 @@ #include "tls.h" +#include "http.h" + #include #include @@ -467,6 +469,56 @@ KJ_TEST("TLS basics") { auto server = serverPromise.wait(test.io.waitScope); test.testConnection(*client, *server); + + // Test clean shutdown. + { + auto eofPromise = server->readAllText(); + KJ_EXPECT(!eofPromise.poll(test.io.waitScope)); + client->shutdownWrite(); + KJ_ASSERT(eofPromise.poll(test.io.waitScope)); + KJ_EXPECT(eofPromise.wait(test.io.waitScope) == ""_kj); + } + + // Test UNCLEAN shutdown in other direction. + { + auto eofPromise = client->readAllText(); + KJ_EXPECT(!eofPromise.poll(test.io.waitScope)); + { auto drop = kj::mv(server); } + KJ_EXPECT(eofPromise.poll(test.io.waitScope)); + KJ_EXPECT_THROW(DISCONNECTED, eofPromise.wait(test.io.waitScope)); + } +} + +KJ_TEST("TLS half-duplex") { + // Test shutting down one direction of a connection but continuing to stream in the other + // direction. + + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + client->shutdownWrite(); + KJ_EXPECT(server->readAllText().wait(test.io.waitScope) == ""); + + for (uint i = 0; i < 100; i++) { + char buffer[7]; + auto writePromise = server->write("foobar", 6); + auto readPromise = client->read(buffer, 6); + writePromise.wait(test.io.waitScope); + readPromise.wait(test.io.waitScope); + buffer[6] = '\0'; + KJ_ASSERT(kj::StringPtr(buffer, 6) == "foobar"); + } + + server->shutdownWrite(); + KJ_EXPECT(client->readAllText().wait(test.io.waitScope) == ""); } KJ_TEST("TLS peer identity") { @@ -589,10 +641,10 @@ kj::Promise readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t co --count; auto buf = kj::heapString(text.size()); auto promise = stream.read(buf.begin(), buf.size()); - return promise.then(kj::mvCapture(buf, [&stream, text, count](kj::String buf) { + return promise.then([&stream, text, buf=kj::mv(buf), count]() { KJ_ASSERT(buf == text, buf, text, count); return readN(stream, text, count); - })); + }); } KJ_TEST("TLS full duplex") { @@ -619,7 +671,10 @@ KJ_TEST("TLS full duplex") { auto writeUp = writeN(*client, "foo", 10000); auto readDown = readN(*client, "bar", 10000); +#if !(_WIN32 && __clang__) + // TODO(someday): work out why this expectation fails even with the above fix KJ_EXPECT(!writeUp.poll(test.io.waitScope)); +#endif KJ_EXPECT(!readDown.poll(test.io.waitScope)); auto writeDown = writeN(*server, "bar", 10000); @@ -651,7 +706,7 @@ KJ_TEST("TLS SNI") { TestSniCallback callback; serverOptions.sniCallback = callback; - TlsTest test(TlsTest::defaultClient(), serverOptions); + TlsTest test(TlsTest::defaultClient(), kj::mv(serverOptions)); ErrorNexus e; auto pipe = test.io.provider->newTwoWayPipe(); @@ -667,7 +722,8 @@ KJ_TEST("TLS SNI") { KJ_ASSERT(callback.callCount == 1); } -void expectInvalidCert(kj::StringPtr hostname, TlsCertificate cert, kj::StringPtr message) { +void expectInvalidCert(kj::StringPtr hostname, TlsCertificate cert, + kj::StringPtr message, kj::Maybe altMessage = nullptr) { TlsKeypair keypair = { TlsPrivateKey(HOST_KEY), kj::mv(cert) }; TlsContext::Options serverOpts; serverOpts.defaultKeypair = keypair; @@ -679,37 +735,74 @@ void expectInvalidCert(kj::StringPtr hostname, TlsCertificate cert, kj::StringPt auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), hostname)); auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); - KJ_EXPECT_THROW_MESSAGE(message, clientPromise.wait(test.io.waitScope)); + clientPromise.then([](kj::Own) { + KJ_FAIL_EXPECT("expected exception"); + }, [message, altMessage](kj::Exception&& e) { + if (kj::_::hasSubstring(e.getDescription(), message)) { + return; + } + + KJ_IF_MAYBE(a, altMessage) { + if (kj::_::hasSubstring(e.getDescription(), *a)) { + return; + } + } + + KJ_FAIL_EXPECT("exception didn't contain expected message", message, + altMessage.orDefault(nullptr), e); + }).wait(test.io.waitScope); } KJ_TEST("TLS certificate validation") { + // Where we've given two possible error texts below, it's because OpenSSL v1 produces the former + // text while v3 produces the latter. Note that as of this writing, our Windows CI build claims + // to be v3 but produces v1 text, for reasons I don't care to investigate. expectInvalidCert("wrong.com", TlsCertificate(kj::str(VALID_CERT, INTERMEDIATE_CERT)), - "Hostname mismatch"); + "Hostname mismatch"_kj, "hostname mismatch"_kj); expectInvalidCert("example.com", TlsCertificate(VALID_CERT), - "unable to get local issuer certificate"); + "unable to get local issuer certificate"_kj); expectInvalidCert("example.com", TlsCertificate(kj::str(EXPIRED_CERT, INTERMEDIATE_CERT)), - "certificate has expired"); + "certificate has expired"_kj); expectInvalidCert("example.com", TlsCertificate(SELF_SIGNED_CERT), - "self signed certificate"); + "self signed certificate"_kj, "self-signed certificate"_kj); } // BoringSSL seems to print error messages differently. #ifdef OPENSSL_IS_BORINGSSL -#define SSL_MESSAGE(interesting, boring) boring +#define SSL_MESSAGE_DIFFERENT_IN_BORINGSSL(interesting, boring) boring #else -#define SSL_MESSAGE(interesting, boring) interesting +#define SSL_MESSAGE_DIFFERENT_IN_BORINGSSL(interesting, boring) interesting #endif KJ_TEST("TLS client certificate verification") { - TlsContext::Options serverOptions = TlsTest::defaultServer(); - TlsContext::Options clientOptions = TlsTest::defaultClient(); + enum class VerifyClients { + YES, + NO + }; + auto makeServerOptionsForClient = []( + const TlsContext::Options& clientOptions, + VerifyClients verifyClients + ) { + TlsContext::Options serverOptions = TlsTest::defaultServer(); + serverOptions.verifyClients = verifyClients == VerifyClients::YES; - serverOptions.verifyClients = true; - serverOptions.trustedCertificates = clientOptions.trustedCertificates; + // Share the certs between the client and server. + serverOptions.trustedCertificates = clientOptions.trustedCertificates; + + return serverOptions; + }; + + TlsKeypair selfSignedKeypair = { TlsPrivateKey(HOST_KEY), TlsCertificate(SELF_SIGNED_CERT) }; + TlsKeypair altKeypair = { + TlsPrivateKey(HOST_KEY2), + TlsCertificate(kj::str(VALID_CERT2, INTERMEDIATE_CERT)), + }; - // No certificate loaded in the client: fail { - TlsTest test(clientOptions, serverOptions); + // No certificate loaded in the client: fail + auto clientOptions = TlsTest::defaultClient(); + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); auto pipe = test.io.provider->newTwoWayPipe(); @@ -720,24 +813,24 @@ KJ_TEST("TLS client certificate verification") { }); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); - KJ_EXPECT_THROW_MESSAGE( - SSL_MESSAGE("peer did not return a certificate", - "PEER_DID_NOT_RETURN_A_CERTIFICATE"), - serverPromise.wait(test.io.waitScope)); -#if !KJ_NO_EXCEPTIONS // if exceptions are disabled, we're now in a bad state because - // KJ_EXPECT_THROW_MESSAGE() runs in a forked child process. - KJ_EXPECT_THROW_MESSAGE( - SSL_MESSAGE("alert", // "alert handshake failure" or "alert certificate required" - "CERTIFICATE_REQUIRED"), - clientPromise.wait(test.io.waitScope)); -#endif + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("peer did not return a certificate", + "PEER_DID_NOT_RETURN_A_CERTIFICATE"), + serverPromise.ignoreResult().wait(test.io.waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL( + "alert", // "alert handshake failure" or "alert certificate required" + "ALERT"), // "ALERT_HANDSHAKE_FAILURE" or "ALERT_CERTIFICATE_REQUIRED" + clientPromise.ignoreResult().wait(test.io.waitScope)); } - // Self-signed certificate loaded in the client: fail - TlsKeypair selfSignedKeypair = { TlsPrivateKey(HOST_KEY), TlsCertificate(SELF_SIGNED_CERT) }; - clientOptions.defaultKeypair = selfSignedKeypair; { - TlsTest test(clientOptions, serverOptions); + // Self-signed certificate loaded in the client: fail + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = selfSignedKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); auto pipe = test.io.provider->newTwoWayPipe(); @@ -748,27 +841,24 @@ KJ_TEST("TLS client certificate verification") { }); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); - KJ_EXPECT_THROW_MESSAGE( - SSL_MESSAGE("certificate verify failed", - "CERTIFICATE_VERIFY_FAILED"), - serverPromise.wait(test.io.waitScope)); -#if !KJ_NO_EXCEPTIONS // if exceptions are disabled, we're now in a bad state because - // KJ_EXPECT_THROW_MESSAGE() runs in a forked child process. - KJ_EXPECT_THROW_MESSAGE( - SSL_MESSAGE("alert unknown ca", - "TLSV1_ALERT_UNKNOWN_CA"), - clientPromise.wait(test.io.waitScope)); -#endif + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("certificate verify failed", + "CERTIFICATE_VERIFY_FAILED"), + serverPromise.ignoreResult().wait(test.io.waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("alert unknown ca", + "TLSV1_ALERT_UNKNOWN_CA"), + clientPromise.ignoreResult().wait(test.io.waitScope)); } - // Trusted certificate loaded in the client: success. - TlsKeypair altKeypair = { - TlsPrivateKey(HOST_KEY2), - TlsCertificate(kj::str(VALID_CERT2, INTERMEDIATE_CERT)) - }; - clientOptions.defaultKeypair = altKeypair; { - TlsTest test(clientOptions, serverOptions); + // Trusted certificate loaded in the client: success. + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = altKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + ErrorNexus e; auto pipe = test.io.provider->newTwoWayPipe(); @@ -787,10 +877,12 @@ KJ_TEST("TLS client certificate verification") { test.testConnection(*client, *server.stream); } - // If verifyClients is off, client certificate is ignored, even if trusted. - serverOptions.verifyClients = false; { - TlsTest test(clientOptions, serverOptions); + // If verifyClients is off, client certificate is ignored, even if trusted. + auto clientOptions = TlsTest::defaultClient(); + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::NO); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + ErrorNexus e; auto pipe = test.io.provider->newTwoWayPipe(); @@ -806,10 +898,14 @@ KJ_TEST("TLS client certificate verification") { KJ_EXPECT(!id->hasCertificate()); } - // Non-trusted keys are ignored too (not errors). - clientOptions.defaultKeypair = selfSignedKeypair; { - TlsTest test(clientOptions, serverOptions); + // Non-trusted keys are ignored too (not errors). + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = selfSignedKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::NO); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + ErrorNexus e; auto pipe = test.io.provider->newTwoWayPipe(); @@ -1096,6 +1192,73 @@ KJ_TEST("TLS receiver does not stall on hung client") { KJ_EXPECT(!extraAcceptPromise.poll(test.io.waitScope)); } +kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount)); + }); +} + +kj::Promise expectEnd(kj::AsyncInputStream& in) { + static char buffer; + + auto promise = in.tryRead(&buffer, 1, 1); + return promise.then([](size_t amount) { + KJ_ASSERT(amount == 0, "expected EOF"); + }); +} + +KJ_TEST("NetworkHttpClient connect with tlsStarter") { + auto io = kj::setupAsyncIo(); + auto listener1 = io.provider->getNetwork().parseAddress("127.0.0.1", 0) + .wait(io.waitScope)->listen(); + + auto acceptLoop KJ_UNUSED = listener1->accept().then([](Own stream) { + return stream->pumpTo(*stream).attach(kj::mv(stream)).ignoreResult(); + }).eagerlyEvaluate(nullptr); + + HttpClientSettings clientSettings; + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + TlsContext tls; + + auto tlsNetwork = tls.wrapNetwork(io.provider->getNetwork()); + clientSettings.tlsContext = tls; + auto client = newHttpClient(clientTimer, headerTable, + io.provider->getNetwork(), *tlsNetwork, clientSettings); + kj::HttpConnectSettings httpConnectSettings = { false, nullptr }; + kj::TlsStarterCallback tlsStarter; + httpConnectSettings.tlsStarter = tlsStarter; + auto request = client->connect( + kj::str("127.0.0.1:", listener1->getPort()), HttpHeaders(headerTable), httpConnectSettings); + + KJ_ASSERT(tlsStarter != nullptr); + + auto buf = kj::heapArray(4); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(request.connection->write("hello", 5)); + promises.add(expectRead(*request.connection, "hello"_kj)); + kj::joinPromisesFailFast(promises.finish()) + .then([io=kj::mv(request.connection)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }).attach(kj::mv(listener1)).wait(io.waitScope); +} + #ifdef KJ_EXTERNAL_TESTS KJ_TEST("TLS to capnproto.org") { kj::AsyncIoContext io = setupAsyncIo(); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.c++ index 05bf5d53928..6affeb1f671 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.c++ @@ -51,19 +51,33 @@ namespace kj { namespace { -KJ_NORETURN(void throwOpensslError()); -void throwOpensslError() { - // Call when an OpenSSL function returns an error code to convert that into an exception and - // throw it. +kj::Exception getOpensslError() { + // Call when an OpenSSL function returns an error code to convert that into an exception. kj::Vector lines; while (unsigned long long error = ERR_get_error()) { +#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING + // OpenSSL 3.0+ reports unexpected disconnects this way. + if (ERR_GET_REASON(error) == SSL_R_UNEXPECTED_EOF_WHILE_READING) { + return KJ_EXCEPTION(DISCONNECTED, + "peer disconnected without gracefully ending TLS session"); + } +#endif + char message[1024]; ERR_error_string_n(error, message, sizeof(message)); lines.add(kj::heapString(message)); } kj::String message = kj::strArray(lines, "\n"); - KJ_FAIL_ASSERT("OpenSSL error", message); + return KJ_EXCEPTION(FAILED, "OpenSSL error", message); +} + +KJ_NORETURN(void throwOpensslError()); +void throwOpensslError() { + // Call when an OpenSSL function returns an error code to convert that into an exception and + // throw it. + + kj::throwFatalException(getOpensslError()); } #if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL) @@ -100,6 +114,37 @@ inline void ensureOpenSslInitialized() { } #endif +bool isIpAddress(kj::StringPtr addr) { + bool isPossiblyIp6 = true; + bool isPossiblyIp4 = true; + uint colonCount = 0; + uint dotCount = 0; + for (auto c: addr) { + if (c == ':') { + isPossiblyIp4 = false; + ++colonCount; + } else if (c == '.') { + isPossiblyIp6 = false; + ++dotCount; + } else if ('0' <= c && c <= '9') { + // Digit is valid for ipv4 or ipv6. + } else if (('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) { + // Hex digit could be ipv6 but not ipv4. + isPossiblyIp4 = false; + } else { + // Nope. + return false; + } + } + + // An IPv4 address has 3 dots. (Yes, I'm aware that technically IPv4 addresses can be formatted + // with fewer dots, but it's not clear that we actually want to support TLS authentication of + // non-canonical address formats, so for now I'm not. File a bug if you care.) An IPv6 address + // has at least 2 and as many as 7 colons. + return (isPossiblyIp4 && dotCount == 3) + || (isPossiblyIp6 && colonCount >= 2 && colonCount <= 7); +} + } // namespace // ======================================================================================= @@ -138,28 +183,42 @@ public: kj::Promise connect(kj::StringPtr expectedServerHostname) { if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) { - throwOpensslError(); + return getOpensslError(); } X509_VERIFY_PARAM* verify = SSL_get0_param(ssl); if (verify == nullptr) { - throwOpensslError(); + return getOpensslError(); } - if (X509_VERIFY_PARAM_set1_host( - verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) { - throwOpensslError(); + if (isIpAddress(expectedServerHostname)) { + if (X509_VERIFY_PARAM_set1_ip_asc(verify, expectedServerHostname.cStr()) <= 0) { + return getOpensslError(); + } + } else { + if (X509_VERIFY_PARAM_set1_host( + verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) { + return getOpensslError(); + } } + // As of OpenSSL 1.1.0, X509_V_FLAG_TRUSTED_FIRST is on by default. Turning it on for older + // versions -- as well as certain OpenSSL-compatible libraries -- fixes the problem described + // here: https://community.letsencrypt.org/t/openssl-client-compatibility-changes-for-let-s-encrypt-certificates/143816 + // + // Otherwise, certificates issued by Let's Encrypt won't work as of September 30, 2021: + // https://letsencrypt.org/docs/dst-root-ca-x3-expiration-september-2021/ + X509_VERIFY_PARAM_set_flags(verify, X509_V_FLAG_TRUSTED_FIRST); + return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) { X509* cert = SSL_get_peer_certificate(ssl); - KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate"); + KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate") { return; } X509_free(cert); auto result = SSL_get_verify_result(ssl); if (result != X509_V_OK) { const char* reason = X509_verify_cert_error_string(result); - KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason); + KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason) { break; } } }); } @@ -208,7 +267,7 @@ public: void shutdownWrite() override { KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()"); - // TODO(0.10): shutdownWrite() is problematic because it doesn't return a promise. It was + // TODO(2.0): shutdownWrite() is problematic because it doesn't return a promise. It was // designed to assume that it would only be called after all writes are finished and that // there was no reason to block at that point, but SSL sessions don't fit this since they // actually have to send a shutdown message. @@ -248,7 +307,6 @@ private: kj::AsyncIoStream& inner; kj::Own ownInner; - bool disconnected = false; kj::Maybe> shutdownTask; ReadyInputStreamWrapper readBuffer; @@ -256,8 +314,6 @@ private: kj::Promise tryReadInternal( void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) { - if (disconnected) return alreadyDone; - return sslCall([this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); }) .then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise { if (n >= minBytes || n == 0) { @@ -299,8 +355,6 @@ private: template kj::Promise sslCall(Func&& func) { - if (disconnected) return size_t(0); - auto result = func(); if (result > 0) { @@ -309,20 +363,22 @@ private: int error = SSL_get_error(ssl, result); switch (error) { case SSL_ERROR_ZERO_RETURN: - disconnected = true; - return size_t(0); + return constPromise(); case SSL_ERROR_WANT_READ: - return readBuffer.whenReady().then(kj::mvCapture(func, - [this](Func&& func) mutable { return sslCall(kj::fwd(func)); })); + return readBuffer.whenReady().then( + [this,func=kj::mv(func)]() mutable { return sslCall(kj::fwd(func)); }); case SSL_ERROR_WANT_WRITE: - return writeBuffer.whenReady().then(kj::mvCapture(func, - [this](Func&& func) mutable { return sslCall(kj::fwd(func)); })); + return writeBuffer.whenReady().then( + [this,func=kj::mv(func)]() mutable { return sslCall(kj::fwd(func)); }); case SSL_ERROR_SSL: - throwOpensslError(); + return getOpensslError(); case SSL_ERROR_SYSCALL: if (result == 0) { - disconnected = true; - return size_t(0); + // OpenSSL pre-3.0 reports unexpected disconnects this way. Note that 3.0+ report it + // as SSL_ERROR_SSL with the reason SSL_R_UNEXPECTED_EOF_WHILE_READING, which is + // handled in throwOpensslError(). + return KJ_EXCEPTION(DISCONNECTED, + "peer disconnected without gracefully ending TLS session"); } else { // According to documentation we shouldn't get here, because our BIO never returns an // "error". But in practice we do get here sometimes when the peer disconnects @@ -359,12 +415,20 @@ private: static long bioCtrl(BIO* b, int cmd, long num, void* ptr) { switch (cmd) { + case BIO_CTRL_EOF: + return reinterpret_cast(BIO_get_data(b))->readBuffer.isAtEnd(); case BIO_CTRL_FLUSH: return 1; case BIO_CTRL_PUSH: case BIO_CTRL_POP: // Informational? return 0; +#ifdef BIO_CTRL_GET_KTLS_SEND + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: + // TODO(someday): Support kTLS if the underlying stream is a raw socket. + return 0; +#endif default: KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd); return 0; @@ -419,18 +483,20 @@ private: class TlsConnectionReceiver final: public ConnectionReceiver, public TaskSet::ErrorHandler { public: - TlsConnectionReceiver(TlsContext &tls, Own inner) + TlsConnectionReceiver( + TlsContext &tls, Own inner, + kj::Maybe acceptErrorHandler) : tls(tls), inner(kj::mv(inner)), acceptLoopTask(acceptLoop().eagerlyEvaluate([this](Exception &&e) { onAcceptFailure(kj::mv(e)); })), + acceptErrorHandler(kj::mv(acceptErrorHandler)), tasks(*this) {} void taskFailed(Exception&& e) override { - // TODO(someday): SSL connection failures may be a fact of normal operation but they may also - // be important diagnostic information. We should allow for an error handler to be passed in so - // that network issues that affect TLS can be more discoverable from the server side. - if (e.getType() != Exception::Type::DISCONNECTED) { + KJ_IF_MAYBE(handler, acceptErrorHandler){ + handler->operator()(kj::mv(e)); + } else if (e.getType() != Exception::Type::DISCONNECTED) { KJ_LOG(ERROR, "error accepting tls connection", kj::mv(e)); } }; @@ -504,6 +570,7 @@ private: Promise acceptLoopTask; ProducerConsumerQueue queue; + kj::Maybe acceptErrorHandler; TaskSet tasks; Maybe maybeInnerException; @@ -520,10 +587,10 @@ public: // So, we make some copies here. auto& tlsRef = tls; auto hostnameCopy = kj::str(hostname); - return inner->connect().then(kj::mvCapture(hostnameCopy, - [&tlsRef](kj::String&& hostname, Own&& stream) { + return inner->connect().then( + [&tlsRef,hostname=kj::mv(hostnameCopy)](Own&& stream) { return tlsRef.wrapClient(kj::mv(stream), hostname); - })); + }); } Promise connectAuthenticated() override { @@ -563,18 +630,58 @@ public: : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {} Promise> parseAddress(StringPtr addr, uint portHint) override { + // We want to parse the hostname or IP address out of `addr`. This is a bit complicated as + // KJ's default network implementation has a fairly featureful grammar for these things. + // In particular, we cannot just split on ':' because the address might be IPv6. + kj::String hostname; - KJ_IF_MAYBE(pos, addr.findFirst(':')) { - hostname = kj::heapString(addr.slice(0, *pos)); + + if (addr.startsWith("[")) { + // IPv6, like "[1234:5678::abcd]:123". Take the part between the brackets. + KJ_IF_MAYBE(pos, addr.findFirst(']')) { + hostname = kj::str(addr.slice(1, *pos)); + } else { + // Uhh??? Just take the whole thing, cert will fail later. + hostname = kj::heapString(addr); + } + } else if (addr.startsWith("unix:") || addr.startsWith("unix-abstract:")) { + // Unfortunately, `unix:123` is ambiguous (maybe there is a host named "unix"?), but the + // default KJ network implementation will interpret it as a Unix domain socket address. + // We don't want TLS to then try to authenticate that as a host named "unix". + KJ_FAIL_REQUIRE("can't authenticate Unix domain socket with TLS", addr); } else { - hostname = kj::heapString(addr); + uint colons = 0; + for (auto c: addr) { + if (c == ':') { + ++colons; + } + } + + if (colons >= 2) { + // Must be an IPv6 address. If it had a port, it would have been wrapped in []. So don't + // strip the port. + hostname = kj::heapString(addr); + } else { + // Assume host:port or ipv4:port. This is a shaky assumption, as the above hacks + // demonstrate. + // + // In theory it might make sense to extend the NetworkAddress interface so that it can tell + // us what the actual parser decided the hostname is. However, when I tried this it proved + // rather cumbersome and actually broke code in the Workers Runtime that does complicated + // stacking of kj::Network implementations. + KJ_IF_MAYBE(pos, addr.findFirst(':')) { + hostname = kj::heapString(addr.slice(0, *pos)); + } else { + hostname = kj::heapString(addr); + } + } } return inner.parseAddress(addr, portHint) - .then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own&& addr) + .then([this, hostname=kj::mv(hostname)](kj::Own&& addr) mutable -> kj::Own { return kj::heap(tls, kj::mv(hostname), kj::mv(addr)); - })); + }); } Own getSockaddr(const void* sockaddr, uint len) override { @@ -670,6 +777,13 @@ TlsContext::TlsContext(Options options) { if (options.minVersion > TlsVersion::TLS_1_2) { optionFlags |= SSL_OP_NO_TLSv1_2; } + if (options.minVersion > TlsVersion::TLS_1_3) { +#ifdef SSL_OP_NO_TLSv1_3 + optionFlags |= SSL_OP_NO_TLSv1_3; +#else + KJ_FAIL_REQUIRE("OpenSSL headers don't support TLS 1.3"); +#endif + } SSL_CTX_set_options(ctx, optionFlags); // note: never fails; returns new options bitmask // honor options.cipherList @@ -712,6 +826,8 @@ TlsContext::TlsContext(Options options) { this->acceptTimeout = *timeout; } + this->acceptErrorHandler = kj::mv(options.acceptErrorHandler); + this->ctx = ctx; } @@ -766,22 +882,24 @@ kj::Promise> TlsContext::wrapClient( kj::Own stream, kj::StringPtr expectedServerHostname) { auto conn = kj::heap(kj::mv(stream), reinterpret_cast(ctx)); auto promise = conn->connect(expectedServerHostname); - return promise.then(kj::mvCapture(conn, [](kj::Own conn) + return promise.then([conn=kj::mv(conn)]() mutable -> kj::Own { return kj::mv(conn); - })); + }); } kj::Promise> TlsContext::wrapServer(kj::Own stream) { auto conn = kj::heap(kj::mv(stream), reinterpret_cast(ctx)); auto promise = conn->accept(); KJ_IF_MAYBE(timeout, acceptTimeout) { - promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise)); + promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake"); + }).exclusiveJoin(kj::mv(promise)); } - return promise.then(kj::mvCapture(conn, [](kj::Own conn) + return promise.then([conn=kj::mv(conn)]() mutable -> kj::Own { return kj::mv(conn); - })); + }); } kj::Promise TlsContext::wrapClient( @@ -798,7 +916,9 @@ kj::Promise TlsContext::wrapServer(kj::AuthenticatedStr auto conn = kj::heap(kj::mv(stream.stream), reinterpret_cast(ctx)); auto promise = conn->accept(); KJ_IF_MAYBE(timeout, acceptTimeout) { - promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise)); + promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake"); + }).exclusiveJoin(kj::mv(promise)); } return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable { auto id = conn->getIdentity(kj::mv(innerId)); @@ -807,7 +927,15 @@ kj::Promise TlsContext::wrapServer(kj::AuthenticatedStr } kj::Own TlsContext::wrapPort(kj::Own port) { - return kj::heap(*this, kj::mv(port)); + auto handler = acceptErrorHandler.map([](TlsErrorHandler& handler) { + return handler.reference(); + }); + return kj::heap(*this, kj::mv(port), kj::mv(handler)); +} + +kj::Own TlsContext::wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname) { + return kj::heap(*this, kj::str(expectedServerHostname), kj::mv(address)); } kj::Own TlsContext::wrapNetwork(kj::Network& network) { @@ -885,7 +1013,7 @@ TlsCertificate::TlsCertificate(kj::ArrayPtr> asn1 for (auto i: kj::indices(asn1)) { auto p = asn1[i].begin(); - // "_AUX" apparently refers to some auxilliary information that can be appended to the + // "_AUX" apparently refers to some auxiliary information that can be appended to the // certificate, but should only be trusted for your own certificate, not the whole chain?? // I don't really know, I'm just cargo-culting. chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size()) @@ -913,7 +1041,7 @@ TlsCertificate::TlsCertificate(kj::StringPtr pem) { KJ_DEFER(BIO_free(bio)); for (auto i: kj::indices(chain)) { - // "_AUX" apparently refers to some auxilliary information that can be appended to the + // "_AUX" apparently refers to some auxiliary information that can be appended to the // certificate, but should only be trusted for your own certificate, not the whole chain?? // I don't really know, I'm just cargo-culting. chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.h index b901e1fa781..f78a23c9994 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/tls.h @@ -30,6 +30,8 @@ #include +KJ_BEGIN_HEADER + namespace kj { class TlsPrivateKey; @@ -40,21 +42,31 @@ class TlsConnection; enum class TlsVersion { SSL_3, // avoid; cryptographically broken - TLS_1_0, - TLS_1_1, - TLS_1_2 + TLS_1_0, // avoid; cryptographically weak + TLS_1_1, // avoid; cryptographically weak + TLS_1_2, + TLS_1_3 }; -class TlsContext { +using TlsErrorHandler = kj::Function; +// Use a simple kj::Function for handling errors during parallel accept(). + +class TlsContext: public kj::SecureNetworkWrapper { // TLS system. Allocate one of these, configure it with the proper keys and certificates (or // use the defaults), and then use it to wrap the standard KJ network interfaces in // implementations that transparently use TLS. public: + struct Options { Options(); // Initializes all values to reasonable defaults. + KJ_DISALLOW_COPY(Options); + Options(Options&&) = default; + Options& operator=(Options&&) = default; + // Options is a move-only value type. + bool useSystemTrustStore; // Whether or not to trust the system's default trust store. Default: true. @@ -76,7 +88,7 @@ class TlsContext { kj::StringPtr cipherList; // OpenSSL cipher list string. The default is a curated list designed to be compatible with - // almost all software in curent use (specifically, based on Mozilla's "intermediate" + // almost all software in current use (specifically, based on Mozilla's "intermediate" // recommendations). The defaults will change in future versions of this library to account // for the latest cryptanalysis. // @@ -97,11 +109,14 @@ class TlsContext { kj::Maybe acceptTimeout; // Timeout applied to accepting a new TLS connection. `timer` is required if this is set. + + kj::Maybe acceptErrorHandler; + // Error handler used for TLS accept errors. }; TlsContext(Options options = Options()); ~TlsContext() noexcept(false); - KJ_DISALLOW_COPY(TlsContext); + KJ_DISALLOW_COPY_AND_MOVE(TlsContext); kj::Promise> wrapServer(kj::Own stream); // Upgrade a regular network stream to TLS and begin the initial handshake as the server. The @@ -129,6 +144,12 @@ class TlsContext { // Upgrade a ConnectionReceiver to one that automatically upgrades all accepted connections to // TLS (acting as the server). + kj::Own wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname); + // Upgrade a NetworkAddress to one that automatically upgrades all connections to TLS, acting + // as the client when `connect()` is called or the server if `listen()` is called. + // `connect()` will athenticate the server as `expectedServerHostname`. + kj::Own wrapNetwork(kj::Network& network); // Upgrade a Network to one that automatically upgrades all connections to TLS. The network will // only accept addresses of the form "hostname" and "hostname:port" (it does not accept raw IP @@ -138,6 +159,7 @@ class TlsContext { void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here kj::Maybe timer; kj::Maybe acceptTimeout; + kj::Maybe acceptErrorHandler; struct SniCallback; }; @@ -246,7 +268,7 @@ class TlsSniCallback { class TlsPeerIdentity final: public kj::PeerIdentity { public: - KJ_DISALLOW_COPY(TlsPeerIdentity); + KJ_DISALLOW_COPY_AND_MOVE(TlsPeerIdentity); ~TlsPeerIdentity() noexcept(false); kj::String toString() override; @@ -283,3 +305,5 @@ class TlsPeerIdentity final: public kj::PeerIdentity { }; } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/compat/url.h b/libs/EXTERNAL/capnproto/c++/src/kj/compat/url.h index 2001adf4331..6e38d230612 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/compat/url.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/compat/url.h @@ -25,6 +25,8 @@ #include #include +KJ_BEGIN_HEADER + namespace kj { struct UrlOptions { @@ -145,3 +147,5 @@ struct Url { }; } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/debug-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/debug-test.c++ index 505ae3e785f..3c65b5218b3 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/debug-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/debug-test.c++ @@ -40,11 +40,6 @@ #include #endif -#if _MSC_VER && !defined(__clang__) -#pragma warning(disable: 4996) -// Warns that sprintf() is buffer-overrunny. Yeah, I know, it's cool. -#endif - namespace kj { namespace _ { // private namespace { @@ -203,7 +198,7 @@ std::string fileLine(std::string file, int line) { file += ':'; char buffer[32]; - sprintf(buffer, "%d", line); + snprintf(buffer, sizeof(buffer), "%d", line); file += buffer; return file; } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/debug.h b/libs/EXTERNAL/capnproto/c++/src/kj/debug.h index 25659c932fb..9f8459b1cee 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/debug.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/debug.h @@ -67,6 +67,13 @@ // * `KJ_REQUIRE(condition, ...)`: Like `KJ_ASSERT` but used to check preconditions -- e.g. to // validate parameters passed from a caller. A failure indicates that the caller is buggy. // +// * `KJ_ASSUME(condition, ...)`: Like `KJ_ASSERT`, but in release mode (if KJ_DEBUG is not +// defined; see below) instead warrants to the compiler that the condition can be assumed to +// hold, allowing it to optimize accordingly. This can result in undefined behavior, so use +// this macro *only* if you can prove to your satisfaction that the condition is guaranteed by +// surrounding code, and if the condition failing to hold would in any case result in undefined +// behavior in its dependencies. +// // * `KJ_SYSCALL(code, ...)`: Executes `code` assuming it makes a system call. A negative result // is considered an error, with error code reported via `errno`. EINTR is handled by retrying. // Other errors are handled by throwing an exception. If you need to examine the return code, @@ -98,11 +105,12 @@ // omits the first parameter and behaves like it was `false`. `FAIL_SYSCALL` and // `FAIL_RECOVERABLE_SYSCALL` take a string and an OS error number as the first two parameters. // The string should be the name of the failed system call. -// * For every macro `FOO` above, there is a `DFOO` version (or `RECOVERABLE_DFOO`) which is only -// executed in debug mode, i.e. when KJ_DEBUG is defined. KJ_DEBUG is defined automatically -// by common.h when compiling without optimization (unless NDEBUG is defined), but you can also -// define it explicitly (e.g. -DKJ_DEBUG). Generally, production builds should NOT use KJ_DEBUG -// as it may enable expensive checks that are unlikely to fail. +// * For every macro `FOO` above except `ASSUME`, there is a `DFOO` version (or +// `RECOVERABLE_DFOO`) which is only executed in debug mode, i.e. when KJ_DEBUG is defined. +// KJ_DEBUG is defined automatically by common.h when compiling without optimization (unless +// NDEBUG is defined), but you can also define it explicitly (e.g. -DKJ_DEBUG). Generally, +// production builds should NOT use KJ_DEBUG as it may enable expensive checks that are unlikely +// to fail. #pragma once @@ -114,7 +122,7 @@ KJ_BEGIN_HEADER namespace kj { -#if _MSC_VER && !defined(__clang__) +#if KJ_MSVC_TRADITIONAL_CPP // MSVC does __VA_ARGS__ differently from GCC: // - A trailing comma before an empty __VA_ARGS__ is removed automatically, whereas GCC wants // you to request this behavior with "##__VA_ARGS__". @@ -274,6 +282,20 @@ namespace kj { ::kj::_::Debug::ContextImpl \ KJ_UNIQUE_NAME(_kjContext)(KJ_UNIQUE_NAME(_kjContextFunc)) +#if _MSC_VER && !defined(__clang__) + +#define KJ_REQUIRE_NONNULL(value, ...) \ + (*([&] { \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (KJ_UNLIKELY(!_kj_result)) { \ + ::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ + #value " != nullptr", #__VA_ARGS__, ##__VA_ARGS__).fatal(); \ + } \ + return _kj_result; \ + }())) + +#else + #define KJ_REQUIRE_NONNULL(value, ...) \ (*({ \ auto _kj_result = ::kj::_::readMaybe(value); \ @@ -284,6 +306,8 @@ namespace kj { kj::mv(_kj_result); \ })) +#endif + #define KJ_EXCEPTION(type, ...) \ ::kj::Exception(::kj::Exception::Type::type, __FILE__, __LINE__, \ ::kj::_::Debug::makeDescription(#__VA_ARGS__, ##__VA_ARGS__)) @@ -342,10 +366,21 @@ namespace kj { #define KJ_DLOG KJ_LOG #define KJ_DASSERT KJ_ASSERT #define KJ_DREQUIRE KJ_REQUIRE +#define KJ_ASSUME KJ_ASSERT #else #define KJ_DLOG(...) do {} while (false) #define KJ_DASSERT(...) do {} while (false) #define KJ_DREQUIRE(...) do {} while (false) +#if defined(__GNUC__) +#define KJ_ASSUME(cond, ...) do { if (cond) {} else __builtin_unreachable(); } while (false) +#elif defined(__clang__) +#define KJ_ASSUME(cond, ...) __builtin_assume(cond) +#elif defined(_MSC_VER) +#define KJ_ASSUME(cond, ...) __assume(cond) +#else +#define KJ_ASSUME(...) do {} while (false) +#endif + #endif namespace _ { // private @@ -432,7 +467,7 @@ class Debug { class Context: public ExceptionCallback { public: Context(); - KJ_DISALLOW_COPY(Context); + KJ_DISALLOW_COPY_AND_MOVE(Context); virtual ~Context() noexcept(false); struct Value { @@ -462,7 +497,7 @@ class Debug { class ContextImpl: public Context { public: inline ContextImpl(Func& func): func(func) {} - KJ_DISALLOW_COPY(ContextImpl); + KJ_DISALLOW_COPY_AND_MOVE(ContextImpl); Value evaluate() override { return func(); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/encoding-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/encoding-test.c++ index 50b1223dda5..7c02b944638 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/encoding-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/encoding-test.c++ @@ -28,6 +28,10 @@ namespace { CappedArray hex(byte i) { return kj::hex((uint8_t )i); } CappedArray hex(char i) { return kj::hex((uint8_t )i); } +#if __cpp_char8_t +[[maybe_unused]] +CappedArray hex(char8_t i) { return kj::hex((uint8_t )i); } +#endif CappedArray hex(char16_t i) { return kj::hex((uint16_t)i); } CappedArray hex(char32_t i) { return kj::hex((uint32_t)i); } CappedArray hex(wchar_t i) { return kj::hex((uint32_t)i); } @@ -58,7 +62,7 @@ void expectRes(EncodingResult result, expectResImpl(kj::mv(result), arrayPtr(expected, s - 1), errors); } -#if __cplusplus >= 202000L +#if __cpp_char8_t template void expectRes(EncodingResult result, const char8_t (&expected)[s], diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/encoding.h b/libs/EXTERNAL/capnproto/c++/src/kj/encoding.h index d61ee473b52..293ecaf1daf 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/encoding.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/encoding.h @@ -372,7 +372,7 @@ EncodingResult> decodeBase64(const char (&text)[s]) { return decodeBase64(arrayPtr(text, s - 1)); } -#if __cplusplus >= 202000L +#if __cpp_char8_t template inline EncodingResult> encodeUtf16(const char8_t (&text)[s], bool nulTerminate=false) { return encodeUtf16(arrayPtr(reinterpret_cast(text), s - 1), nulTerminate); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/exception-override-symbolizer-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception-override-symbolizer-test.c++ new file mode 100644 index 00000000000..bb9de024966 --- /dev/null +++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception-override-symbolizer-test.c++ @@ -0,0 +1,49 @@ +// Copyright (c) 2022 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if __GNUC__ && !_WIN32 + +#include "debug.h" +#include +#include "kj/common.h" +#include "kj/array.h" +#include +#include +#include + +namespace kj { + +// override weak symbol +String stringifyStackTrace(ArrayPtr trace) { + return kj::str("\n\nTEST_SYMBOLIZER\n\n"); +} + +namespace { + +KJ_TEST("getStackTrace() uses symbolizer override") { + auto trace = getStackTrace(); + KJ_ASSERT(strstr(trace.cStr(), "TEST_SYMBOLIZER") != nullptr, trace); +} + +} // namespace +} // namespace kj + +#endif diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/exception-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception-test.c++ index 2cc37d60bd3..50054ab24c2 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/exception-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception-test.c++ @@ -132,10 +132,16 @@ TEST(Exception, UnwindDetector) { } #endif +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) || \ + KJ_HAS_COMPILER_FEATURE(address_sanitizer) || \ + defined(__SANITIZE_ADDRESS__) +// The implementation skips this check in these cases. +#else #if !__MINGW32__ // Inexplicably crashes when exception is thrown from constructor. TEST(Exception, ExceptionCallbackMustBeOnStack) { KJ_EXPECT_THROW_MESSAGE("must be allocated on the stack", new ExceptionCallback); } +#endif #endif // !__MINGW32__ #if !KJ_NO_EXCEPTIONS @@ -196,7 +202,7 @@ KJ_TEST("getStackTrace() returns correct line number, not line + 1") { // contain the right one. // 2) This test only detects the problem if the call instruction to testStackTrace() is the // *last* instruction attributed to its line of code. Whether or not this is true seems to be - // dependent on obscure complier behavior. For example, below, it could only be the case if + // dependent on obscure compiler behavior. For example, below, it could only be the case if // RVO is applied -- but in my testing, RVO does seem to be applied here. I tried several // variations involving passing via an output parameter or a global variable rather than // returning, but found some variations detected the problem and others didn't, essentially diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/exception.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception.c++ index c2dda506ccf..75e3179db4e 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/exception.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception.c++ @@ -254,6 +254,13 @@ ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount) { #endif } +#if (__GNUC__ && !_WIN32) || __clang__ +// Allow dependents to override the implementation of stack symbolication by making it a weak +// symbol. We prefer weak symbols over some sort of callback registration mechanism becasue this +// allows an alternate symbolication library to be easily linked into tests without changing the +// code of the test. +__attribute__((weak)) +#endif String stringifyStackTrace(ArrayPtr trace) { if (trace.size() == 0) return nullptr; if (getExceptionCallback().stackTraceMode() != ExceptionCallback::StackTraceMode::FULL) { @@ -278,7 +285,8 @@ String stringifyStackTrace(ArrayPtr trace) { IMAGEHLP_LINE64 lineInfo; memset(&lineInfo, 0, sizeof(lineInfo)); lineInfo.SizeOfStruct = sizeof(lineInfo); - if (dbghelp.symGetLineFromAddr64(process, reinterpret_cast(trace[i]), NULL, &lineInfo)) { + DWORD displacement; + if (dbghelp.symGetLineFromAddr64(process, reinterpret_cast(trace[i]), &displacement, &lineInfo)) { lines[i] = kj::str('\n', lineInfo.FileName, ':', lineInfo.LineNumber); } } @@ -714,6 +722,29 @@ retry: return filename; } +void resetCrashHandlers() { +#ifndef _WIN32 + struct sigaction action; + memset(&action, 0, sizeof(action)); + + action.sa_handler = SIG_DFL; + KJ_SYSCALL(sigaction(SIGSEGV, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGBUS, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGFPE, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGABRT, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGILL, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGSYS, &action, nullptr)); + +#ifdef KJ_DEBUG + KJ_SYSCALL(sigaction(SIGINT, &action, nullptr)); +#endif +#endif + +#if !KJ_NO_EXCEPTIONS + std::set_terminate(nullptr); +#endif +} + StringPtr KJ_STRINGIFY(Exception::Type type) { static const char* TYPE_STRINGS[] = { "failed", @@ -805,6 +836,15 @@ void Exception::wrapContext(const char* file, int line, String&& description) { } void Exception::extendTrace(uint ignoreCount, uint limit) { + if (isFullTrace) { + // Awkward: extendTrace() was called twice without truncating in between. This should probably + // be an error, but historically we didn't check for this so I'm hesitant to make it an error + // now. We shouldn't actually extend the trace, though, as our current trace is presumably + // rooted in main() and it'd be weird to append frames "above" that. + // TODO(cleanup): Abort here and see what breaks? + return; + } + KJ_STACK_ARRAY(void*, newTraceSpace, kj::min(kj::size(trace), limit) + ignoreCount + 1, sizeof(trace)/sizeof(trace[0]) + 8, 128); @@ -816,10 +856,26 @@ void Exception::extendTrace(uint ignoreCount, uint limit) { // Copy the rest into our trace. memcpy(trace + traceCount, newTrace.begin(), newTrace.asBytes().size()); traceCount += newTrace.size(); + isFullTrace = true; } } void Exception::truncateCommonTrace() { + if (isFullTrace) { + // We're truncating the common portion of the full trace, turning it back into a limited + // trace. + isFullTrace = false; + } else { + // If the trace was never extended in the first place, trying to truncate it is at best a waste + // of time and at worst might remove information for no reason. So, don't. + // + // This comes up in particular in coroutines, when the exception originated from a co_awaited + // promise. In that case we manually add the one relevant frame to the trace, rather than + // call extendTrace() just to have to truncate most of it again a moment later in the + // unhandled_exception() callback. + return; + } + if (traceCount > 0) { // Create a "reference" stack trace that is a little bit deeper than the one in the exception. void* refTraceSpace[sizeof(this->trace) / sizeof(this->trace[0]) + 4]; @@ -857,6 +913,9 @@ void Exception::truncateCommonTrace() { } void Exception::addTrace(void* ptr) { + // TODO(cleanup): Abort here if isFullTrace is true, and see what breaks. This method only makes + // sense to call on partial traces. + if (traceCount < kj::size(trace)) { trace[traceCount++] = ptr; } @@ -969,14 +1028,21 @@ KJ_THREADLOCAL_PTR(ExceptionCallback) threadLocalCallback = nullptr; } // namespace -ExceptionCallback::ExceptionCallback(): next(getExceptionCallback()) { +void requireOnStack(void* ptr, kj::StringPtr description) { +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) || \ + KJ_HAS_COMPILER_FEATURE(address_sanitizer) || \ + defined(__SANITIZE_ADDRESS__) + // When using libfuzzer or ASAN, this sanity check may spurriously fail, so skip it. +#else char stackVar; -#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - ptrdiff_t offset = reinterpret_cast(this) - &stackVar; - KJ_ASSERT(offset < 65536 && offset > -65536, - "ExceptionCallback must be allocated on the stack."); + ptrdiff_t offset = reinterpret_cast(ptr) - &stackVar; + KJ_REQUIRE(offset < 65536 && offset > -65536, + kj::str(description)); #endif +} +ExceptionCallback::ExceptionCallback(): next(getExceptionCallback()) { + requireOnStack(this, "ExceptionCallback must be allocated on the stack."); threadLocalCallback = this; } @@ -1124,13 +1190,13 @@ ExceptionCallback& getExceptionCallback() { } void throwFatalException(kj::Exception&& exception, uint ignoreCount) { - exception.extendTrace(ignoreCount + 1); + if (ignoreCount != (uint)kj::maxValue) exception.extendTrace(ignoreCount + 1); getExceptionCallback().onFatalException(kj::mv(exception)); abort(); } void throwRecoverableException(kj::Exception&& exception, uint ignoreCount) { - exception.extendTrace(ignoreCount + 1); + if (ignoreCount != (uint)kj::maxValue) exception.extendTrace(ignoreCount + 1); getExceptionCallback().onRecoverableException(kj::mv(exception)); } @@ -1138,7 +1204,7 @@ void throwRecoverableException(kj::Exception&& exception, uint ignoreCount) { namespace _ { // private -#if __cplusplus >= 201703L +#if KJ_CPP_STD >= 201703L uint uncaughtExceptionCount() { return std::uncaught_exceptions(); @@ -1220,12 +1286,14 @@ bool UnwindDetector::isUnwinding() const { return _::uncaughtExceptionCount() > uncaughtCount; } -void UnwindDetector::catchExceptionsAsSecondaryFaults(_::Runnable& runnable) const { +#if !KJ_NO_EXCEPTIONS +void UnwindDetector::catchThrownExceptionAsSecondaryFault() const { // TODO(someday): Attach the secondary exception to whatever primary exception is causing // the unwind. For now we just drop it on the floor as this is probably fine most of the // time. - runCatchingExceptions(runnable); + getCaughtExceptionAsKj(); } +#endif #if __GNUC__ && !KJ_NO_RTTI static kj::String demangleTypeName(const char* name) { @@ -1294,6 +1362,8 @@ kj::ArrayPtr computeRelativeTrace( return bestMatch; } +#if KJ_NO_EXCEPTIONS + namespace _ { // private class RecoverableExceptionCatcher: public ExceptionCallback { @@ -1315,17 +1385,21 @@ public: }; Maybe runCatchingExceptions(Runnable& runnable) { -#if KJ_NO_EXCEPTIONS RecoverableExceptionCatcher catcher; runnable.run(); KJ_IF_MAYBE(e, catcher.caught) { e->truncateCommonTrace(); } return mv(catcher.caught); -#else +} + +} // namespace _ (private) + +#else // KJ_NO_EXCEPTIONS + +kj::Exception getCaughtExceptionAsKj() { try { - runnable.run(); - return nullptr; + throw; } catch (Exception& e) { e.truncateCommonTrace(); return kj::mv(e); @@ -1347,9 +1421,7 @@ Maybe runCatchingExceptions(Runnable& runnable) { return Exception(Exception::Type::FAILED, "(unknown)", -1, str("unknown non-KJ exception")); #endif } -#endif } - -} // namespace _ (private) +#endif // !KJ_NO_EXCEPTIONS } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/exception.h b/libs/EXTERNAL/capnproto/c++/src/kj/exception.h index 8c20b1b39ea..be90163f933 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/exception.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/exception.h @@ -80,6 +80,8 @@ class Exception { StringPtr getDescription() const { return description; } ArrayPtr getStackTrace() const { return arrayPtr(trace, traceCount); } + void setDescription(kj::String&& desc) { description = kj::mv(desc); } + StringPtr getRemoteTrace() const { return remoteTrace; } void setRemoteTrace(kj::String&& value) { remoteTrace = kj::mv(value); } // Additional stack trace data originating from a remote server. If present, then @@ -142,6 +144,19 @@ class Exception { void* trace[32]; uint traceCount; + bool isFullTrace = false; + // Is `trace` a full trace to the top of the stack (or as close as we could get before we ran + // out of space)? If this is false, then `trace` is instead a partial trace covering just the + // frames between where the exception was thrown and where it was caught. + // + // extendTrace() transitions this to true, and truncateCommonTrace() changes it back to false. + // + // In theory, an exception should only hold a full trace when it is in the process of being + // thrown via the C++ exception handling mechanism -- extendTrace() is called before the throw + // and truncateCommonTrace() after it is caught. Note that when exceptions propagate through + // async promises, the trace is extended one frame at a time instead, so isFullTrace should + // remain false. + friend class ExceptionImpl; }; @@ -185,7 +200,7 @@ class ExceptionCallback { public: ExceptionCallback(); - KJ_DISALLOW_COPY(ExceptionCallback); + KJ_DISALLOW_COPY_AND_MOVE(ExceptionCallback); virtual ~ExceptionCallback() noexcept(false); virtual void onRecoverableException(Exception&& exception); @@ -275,6 +290,20 @@ Maybe runCatchingExceptions(Func&& func); // If exception are disabled (e.g. with -fno-exceptions), this will still detect whether any // recoverable exceptions occurred while running the function and will return those. +#if !KJ_NO_EXCEPTIONS + +kj::Exception getCaughtExceptionAsKj(); +// Call from the catch block of a try/catch to get a `kj::Exception` representing the exception +// that was caught, the same way that `kj::runCatchingExceptions` would when catching an exception. +// This is sometimes useful if `runCatchingExceptions()` doesn't quite fit your use case. You can +// call this from any catch block, including `catch (...)`. +// +// Some exception types will actually be rethrown by this function, rather than returned. The most +// common example is `CanceledException`, whose purpose is to unwind the stack and is not meant to +// be caught. + +#endif // !KJ_NO_EXCEPTIONS + class UnwindDetector { // Utility for detecting when a destructor is called due to unwind. Useful for: // - Avoiding throwing exceptions in this case, which would terminate the program. @@ -301,9 +330,13 @@ class UnwindDetector { private: uint uncaughtCount; - void catchExceptionsAsSecondaryFaults(_::Runnable& runnable) const; +#if !KJ_NO_EXCEPTIONS + void catchThrownExceptionAsSecondaryFault() const; +#endif }; +#if KJ_NO_EXCEPTIONS + namespace _ { // private class Runnable { @@ -326,20 +359,39 @@ Maybe runCatchingExceptions(Runnable& runnable); } // namespace _ (private) +#endif // KJ_NO_EXCEPTIONS + template Maybe runCatchingExceptions(Func&& func) { +#if KJ_NO_EXCEPTIONS _::RunnableImpl runnable(kj::fwd(func)); return _::runCatchingExceptions(runnable); +#else + try { + func(); + return nullptr; + } catch (...) { + return getCaughtExceptionAsKj(); + } +#endif } template void UnwindDetector::catchExceptionsIfUnwinding(Func&& func) const { +#if KJ_NO_EXCEPTIONS + // Can't possibly be unwinding... + func(); +#else if (isUnwinding()) { - _::RunnableImpl> runnable(kj::fwd(func)); - catchExceptionsAsSecondaryFaults(runnable); + try { + func(); + } catch (...) { + catchThrownExceptionAsSecondaryFault(); + } } else { func(); } +#endif } #define KJ_ON_SCOPE_SUCCESS(code) \ @@ -388,6 +440,9 @@ void printStackTraceOnCrash(); // a stack trace. You should call this as early as possible on program startup. Programs using // KJ_MAIN get this automatically. +void resetCrashHandlers(); +// Resets all signal handlers set by printStackTraceOnCrash(). + kj::StringPtr trimSourceFilename(kj::StringPtr filename); // Given a source code file name, trim off noisy prefixes like "src/" or // "/ekam-provider/canonical/". @@ -440,6 +495,10 @@ kj::ArrayPtr computeRelativeTrace( // // This is useful for debugging, when reporting several related traces at once. +void requireOnStack(void* ptr, kj::StringPtr description); +// Throw an exception if `ptr` does not appear to point to something near the top of the stack. +// Used as a safety check for types that must be stack-allocated, like ExceptionCallback. + } // namespace kj KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-test.c++ index d1d9fa2c98b..5e7596efad5 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-test.c++ @@ -19,10 +19,18 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t and ino_t, otherwise this code will break when either value exceeds 2^32. +#endif + +#include "debug.h" #include "filesystem.h" +#include "string.h" #include "test.h" #include "encoding.h" #include +#include #if _WIN32 #include #include "windows-sanity.h" @@ -207,16 +215,19 @@ bool isWine() { return false; } #endif static Own newTempFile() { - char filename[] = VAR_TMP "/kj-filesystem-test.XXXXXX"; + const char* tmpDir = getenv("TEST_TMPDIR"); + auto filename = str(tmpDir != nullptr ? tmpDir : VAR_TMP, "/kj-filesystem-test.XXXXXX"); int fd; - KJ_SYSCALL(fd = mkstemp(filename)); - KJ_DEFER(KJ_SYSCALL(unlink(filename))); + KJ_SYSCALL(fd = mkstemp(filename.begin())); + KJ_DEFER(KJ_SYSCALL(unlink(filename.cStr()))); return newDiskFile(AutoCloseFd(fd)); } class TempDir { public: - TempDir(): filename(heapString(VAR_TMP "/kj-filesystem-test.XXXXXX")) { + TempDir() { + const char* tmpDir = getenv("TEST_TMPDIR"); + filename = str(tmpDir != nullptr ? tmpDir : VAR_TMP, "/kj-filesystem-test.XXXXXX"); if (mkdtemp(filename.begin()) == nullptr) { KJ_FAIL_SYSCALL("mkdtemp", errno, filename); } @@ -879,9 +890,17 @@ KJ_TEST("DiskFile holes") { // Some filesystems, like BTRFS, report zero `spaceUsed` until synced. file->datasync(); - // Allow for block sizes as low as 512 bytes and as high as 64k. + // Allow for block sizes as low as 512 bytes and as high as 64k. Since we wrote two locations, + // two blocks should be used. auto meta = file->stat(); +#if __FreeBSD__ + // On FreeBSD with ZFS it seems to report 512 bytes used even if I write more than 512 random + // (i.e. non-compressible) bytes. I couldn't figure it out so I'm giving up for now. Maybe it's + // a bug in the system? + KJ_EXPECT(meta.spaceUsed >= 512, meta.spaceUsed); +#else KJ_EXPECT(meta.spaceUsed >= 2 * 512, meta.spaceUsed); +#endif KJ_EXPECT(meta.spaceUsed <= 2 * 65536); byte buf[7]; @@ -935,9 +954,10 @@ KJ_TEST("DiskFile holes") { #endif file->zero(1 << 20, blockSize); file->datasync(); -#if !_WIN32 +#if !_WIN32 && !__FreeBSD__ // TODO(someday): This doesn't work on Windows. I don't know why. We're definitely using the - // proper ioctl. Oh well. + // proper ioctl. Oh well. It also doesn't work on FreeBSD-ZFS, due to the bug(?) mentioned + // earlier -- the size is just always reported as 512. KJ_EXPECT(file->stat().spaceUsed < meta.spaceUsed); #endif KJ_EXPECT(file->read(1 << 20, buf) == 7); @@ -945,5 +965,34 @@ KJ_TEST("DiskFile holes") { } #endif +#if !_WIN32 // Only applies to Unix. +// Ensure the current path is correctly computed. +// +// See issue #1425. +KJ_TEST("DiskFilesystem::computeCurrentPath") { + TempDir tempDir; + auto dir = tempDir.get(); + + // Paths can be PATH_MAX, but the segments which make up that path typically + // can't exceed 255 bytes. + auto maxPathSegment = std::string(255, 'a'); + + // Create a path which exceeds the 256 byte buffer used in + // computeCurrentPath. + auto subdir = dir->openSubdir(Path({ + maxPathSegment, + maxPathSegment, + "some_path_longer_than_256_bytes" + }), WriteMode::CREATE | WriteMode::CREATE_PARENT); + + auto origDir = open(".", O_RDONLY); + KJ_SYSCALL(fchdir(KJ_ASSERT_NONNULL(subdir->getFd()))); + KJ_DEFER(KJ_SYSCALL(fchdir(origDir))); + + // Test computeCurrentPath indirectly. + newDiskFilesystem(); +} +#endif + } // namespace } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-unix.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-unix.c++ index 8c9336238d4..67d7bf22c76 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-unix.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-unix.c++ @@ -25,6 +25,12 @@ #define _GNU_SOURCE #endif +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t. (The code will still work if we get 32-bit off_t as long as actual files +// are under 4GB.) +#endif + #include "filesystem.h" #include "debug.h" #include @@ -182,7 +188,7 @@ static void rmrfChildrenAndClose(int fd) { if (entry->d_type == DT_DIR) { int subdirFd; KJ_SYSCALL(subdirFd = openat( - fd, entry->d_name, O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC)); + fd, entry->d_name, O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC | O_NOFOLLOW)); rmrfChildrenAndClose(subdirFd); KJ_SYSCALL(unlinkat(fd, entry->d_name, AT_REMOVEDIR)); } else if (entry->d_type != DT_UNKNOWN) { @@ -211,7 +217,9 @@ static bool rmrf(int fd, StringPtr path) { if (S_ISDIR(stats.st_mode)) { int subdirFd; KJ_SYSCALL(subdirFd = openat( - fd, path.cStr(), O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC)) { return false; } + fd, path.cStr(), O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC | O_NOFOLLOW)) { + return false; + } rmrfChildrenAndClose(subdirFd); KJ_SYSCALL(unlinkat(fd, path.cStr(), AT_REMOVEDIR)) { return false; } } else { @@ -298,6 +306,11 @@ public: return fd.get(); } + void setFd(AutoCloseFd newFd) { + // Used for one hack in DiskFilesystem's constructor... + fd = kj::mv(newFd); + } + // FsNode -------------------------------------------------------------------- FsNode::Metadata stat() const { @@ -1090,7 +1103,7 @@ public: } } -#if __linux__ && defined(RENAME_EXCHANGE) +#if __linux__ && defined(RENAME_EXCHANGE) && defined(SYS_renameat2) // Try to use Linux's renameat2() to atomically check preconditions and apply. if (has(mode, WriteMode::MODIFY)) { @@ -1111,7 +1124,7 @@ public: // Presumably because the target path doesn't exist. if (has(mode, WriteMode::CREATE)) { KJ_FAIL_ASSERT("rename(tmp, path) claimed path exists but " - "renameat2(fromPath, toPath, EXCAHNGE) said it doest; concurrent modification?", + "renameat2(fromPath, toPath, EXCHANGE) said it doest; concurrent modification?", fromPath, toPath) { return false; } } else { // Assume target doesn't exist. @@ -1650,7 +1663,25 @@ public: DiskFilesystem() : root(openDir("/")), current(openDir(".")), - currentPath(computeCurrentPath()) {} + currentPath(computeCurrentPath()) { + // We sometimes like to use qemu-user to test arm64 binaries cross-compiled from an x64 host + // machine. But, because it intercepts and rewrites system calls from userspace rather than + // emulating a whole kernel, it has a lot of quirks. One quirk that hits kj::Filesystem pretty + // badly is that open("/") actually returns a file descriptor for "/usr/aarch64-linux-gnu". + // Attempts to openat() any files within there then don't work. We can detect this problem and + // correct for it here. + struct stat realRoot, fsRoot; + KJ_SYSCALL_HANDLE_ERRORS(stat("/dev/..", &realRoot)) { + default: + // stat("/dev/..") failed? Give up. + return; + } + KJ_SYSCALL(fstat(root.DiskHandle::getFd(), &fsRoot)); + if (realRoot.st_ino != fsRoot.st_ino) { + KJ_LOG(WARNING, "root dir file descriptor is broken, probably because of qemu; compensating"); + root.setFd(openDir("/dev/..")); + } + } const Directory& getRoot() const override { return root; @@ -1710,7 +1741,7 @@ private: KJ_STACK_ARRAY(char, buf, size, 256, 4096); if (getcwd(buf.begin(), size) == nullptr) { int error = errno; - if (error == ENAMETOOLONG) { + if (error == ERANGE) { size *= 2; goto retry; } else { diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-win32.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-win32.c++ index 7f3442beaed..be761894f69 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-win32.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem-disk-win32.c++ @@ -158,6 +158,17 @@ static void rmrfChildren(ArrayPtr path) { auto glob = join16(path, L"*"); WIN32_FIND_DATAW data; + // TODO(security): If `path` is a reparse point (symlink), this will follow it and delete the + // contents. We check for reparse points before recursing, but there is still a TOCTOU race + // condition. + // + // Apparently, there is a whole different directory-listing API we could be using here: + // `GetFileInformationByHandleEx()`, with the `FileIdBothDirectoryInfo` flag. This lets us + // list the contents of a directory from its already-open handle -- it's probably how we should + // do directory listing in general! If we open a file with FILE_FLAG_OPEN_REPARSE_POINT, then + // the handle will represent the reparse point itself, and attempting to list it will produce + // no entries. I had no idea this API existed when I wrote much of this code; I wish I had + // because it seems much cleaner than the ancient FindFirstFile/FindNextFile API! HANDLE handle = FindFirstFileW(glob.begin(), &data); if (handle == INVALID_HANDLE_VALUE) { auto error = GetLastError(); @@ -575,8 +586,8 @@ public: PathPtr path = KJ_ASSERT_NONNULL(dirPath); auto glob = join16(path.forWin32Api(true), L"*"); - // TODO(perf): Use FindFileEx() with FindExInfoBasic? Not apparently supported on Vista. - // TODO(someday): Use NtQueryDirectoryObject() instead? It's "internal", but so much cleaner. + // TODO(someday): Use GetFileInformationByHandleEx() with FileIdBothDirectoryInfo to enumerate + // directories instead. It's much cleaner. WIN32_FIND_DATAW data; HANDLE handle = FindFirstFileW(glob.begin(), &data); if (handle == INVALID_HANDLE_VALUE) { @@ -674,7 +685,7 @@ public: nativePath(path).begin(), GENERIC_READ, // When opening directories, we do NOT use FILE_SHARE_DELETE, because we need the directory - // path to remain vaild. + // path to remain valid. // // TODO(someday): Use NtCreateFile() and related "internal" APIs that allow for // openat()-like behavior? @@ -932,7 +943,7 @@ public: NULL)) { case ERROR_PATH_NOT_FOUND: if (has(mode, WriteMode::CREATE)) { - // A parent directory didn't exist. Maybe cerate it. + // A parent directory didn't exist. Maybe create it. if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | WriteMode::CREATE_PARENT, true)) { @@ -1233,7 +1244,7 @@ public: // We can't really create symlinks on Windows. Reasons: // - We'd need to know whether the target is a file or a directory to pass the correct flags. // That means we'd need to evaluate the link content and track down the target. What if the - // taget doesn't exist? It's unclear if this is even allowed on Windows. + // target doesn't exist? It's unclear if this is even allowed on Windows. // - Apparently, creating symlinks is a privileged operation on Windows prior to Windows 10. // The flag SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE is very new. KJ_UNIMPLEMENTED( @@ -1286,7 +1297,7 @@ public: case ERROR_PATH_NOT_FOUND: return false; case ERROR_ACCESS_DENIED: - // This usually means that fromPath was a directory or toPath was a direcotry. Fall back + // This usually means that fromPath was a directory or toPath was a directory. Fall back // to default implementation. break; default: diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.c++ index 62b944cf86e..1dff22ba21f 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.c++ @@ -539,7 +539,7 @@ FsNode::Metadata ReadableDirectory::lstat(PathPtr path) const { KJ_IF_MAYBE(meta, tryLstat(path)) { return *meta; } else { - KJ_FAIL_REQUIRE("no such file", path) { break; } + KJ_FAIL_REQUIRE("no such file or directory", path) { break; } return FsNode::Metadata(); } } @@ -548,7 +548,7 @@ Own ReadableDirectory::openFile(PathPtr path) const { KJ_IF_MAYBE(file, tryOpenFile(path)) { return kj::mv(*file); } else { - KJ_FAIL_REQUIRE("no such directory", path) { break; } + KJ_FAIL_REQUIRE("no such file", path) { break; } return newInMemoryFile(nullClock()); } } @@ -557,7 +557,7 @@ Own ReadableDirectory::openSubdir(PathPtr path) const { KJ_IF_MAYBE(dir, tryOpenSubdir(path)) { return kj::mv(*dir); } else { - KJ_FAIL_REQUIRE("no such file or directory", path) { break; } + KJ_FAIL_REQUIRE("no such directory", path) { break; } return newInMemoryDirectory(nullClock()); } } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.h b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.h index de309c4114f..323420a4421 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/filesystem.h @@ -28,6 +28,8 @@ #include "function.h" #include "hash.h" +KJ_BEGIN_HEADER + namespace kj { template @@ -359,7 +361,7 @@ class FsNode { uint64_t hashCode = 0; // Hint which can be used to determine if two FsNode instances point to the same underlying // file object. If two FsNodes report different hashCodes, then they are not the same object. - // If they report the same hashCode, then they may or may not be teh same object. + // If they report the same hashCode, then they may or may not be the same object. // // The Unix filesystem implementation builds the hashCode based on st_dev and st_ino of // `struct stat`. However, note that some filesystems -- especially FUSE-based -- may not fill @@ -880,6 +882,13 @@ class Directory: public ReadableDirectory { // tryRemove() returns false in the specific case that the path doesn't exist. remove() would // throw in this case. In all other error cases (like "access denied"), tryRemove() still throws; // it is only "does not exist" that produces a false return. + // + // WARNING: The Windows implementation of recursive deletion is currently not safe to call from a + // privileged process to delete directories writable by unprivileged users, due to a race + // condition in which the user could trick the algorithm into following a symlink and deleting + // everything at the destination. This race condition is not present in the Unix + // implementation. Fixing it for Windows would require rewriting a lot of code to use different + // APIs. If you're interested, see the TODO(security) in filesystem-disk-win32.c++. // TODO(someday): // - Support sockets? There's no openat()-like interface for sockets, so it's hard to support @@ -938,7 +947,7 @@ Own newInMemoryDirectory(const Clock& clock); // which would expand it will throw. // // InMemoryDirectory has the following special properties: -// - Symlinks are processed using Path::parse(). This implies tha a symlink cannot point to a +// - Symlinks are processed using Path::parse(). This implies that a symlink cannot point to a // parent directory -- InMemoryDirectory does not know its parent. // - link() can link directory nodes in addition to files. // - link() and rename() accept any kind of Directory as `fromDirectory` -- it doesn't need to be @@ -1110,3 +1119,5 @@ void Directory::Replacer::commit() { } } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/hash.h b/libs/EXTERNAL/capnproto/c++/src/kj/hash.h index 750a14b9580..d6ff46fd817 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/hash.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/hash.h @@ -50,6 +50,7 @@ struct HashCoder { inline uint operator*(const Array& s) const { return operator*(s.asBytes()); } inline uint operator*(const String& s) const { return operator*(s.asBytes()); } inline uint operator*(const StringPtr& s) const { return operator*(s.asBytes()); } + inline uint operator*(const ConstString& s) const { return operator*(s.asBytes()); } inline uint operator*(decltype(nullptr)) const { return 0; } inline uint operator*(bool b) const { return b; } @@ -90,6 +91,8 @@ struct HashCoder { template uint operator*(T* ptr) const { + static_assert(!isSameType, char>(), "Wrap in StringPtr if you want to hash string " + "contents. If you want to hash the pointer, cast to void*"); if (sizeof(ptr) == sizeof(uint)) { // TODO(cleanup): In C++17, make the if() above be `if constexpr ()`, then change this to // reinterpret_cast(ptr). @@ -128,6 +131,14 @@ static KJ_CONSTEXPR(const) HashCoder HASHCODER = HashCoder(); inline uint hashCode(uint value) { return value; } template inline uint hashCode(T&& value) { return hashCode(_::HASHCODER * kj::fwd(value)); } +template +inline uint hashCode(T (&arr)[N]) { + static_assert(!isSameType, char>(), "Wrap in StringPtr if you want to hash string " + "contents. If you want to hash the pointer, cast to void*"); + static_assert(isSameType, char>(), "Wrap in ArrayPtr if you want to hash a C array. " + "If you want to hash the pointer, cast to void*"); + return 0; +} template inline uint hashCode(T&&... values) { uint hashes[] = { hashCode(kj::fwd(values))... }; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/io-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/io-test.c++ index 1ec162e2ad3..ea7bac413b3 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/io-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/io-test.c++ @@ -111,6 +111,11 @@ KJ_TEST("VectorOutputStream") { KJ_ASSERT(output.getWriteBuffer().size() == 24); KJ_ASSERT(output.getWriteBuffer().begin() == output.getArray().begin() + 40); + + output.clear(); + KJ_ASSERT(output.getWriteBuffer().begin() == output.getArray().begin()); + KJ_ASSERT(output.getWriteBuffer().size() == 64); + KJ_ASSERT(output.getArray().size() == 0); } class MockInputStream: public InputStream { diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/io.h b/libs/EXTERNAL/capnproto/c++/src/kj/io.h index a09094983f6..3edc300ca5a 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/io.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/io.h @@ -141,7 +141,7 @@ class BufferedInputStreamWrapper: public BufferedInputStream { // If the second parameter is non-null, the stream uses the given buffer instead of allocating // its own. This may improve performance if the buffer can be reused. - KJ_DISALLOW_COPY(BufferedInputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(BufferedInputStreamWrapper); ~BufferedInputStreamWrapper() noexcept(false); // implements BufferedInputStream ---------------------------------- @@ -167,7 +167,7 @@ class BufferedOutputStreamWrapper: public BufferedOutputStream { // If the second parameter is non-null, the stream uses the given buffer instead of allocating // its own. This may improve performance if the buffer can be reused. - KJ_DISALLOW_COPY(BufferedOutputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(BufferedOutputStreamWrapper); ~BufferedOutputStreamWrapper() noexcept(false); void flush(); @@ -193,7 +193,7 @@ class BufferedOutputStreamWrapper: public BufferedOutputStream { class ArrayInputStream: public BufferedInputStream { public: explicit ArrayInputStream(ArrayPtr array); - KJ_DISALLOW_COPY(ArrayInputStream); + KJ_DISALLOW_COPY_AND_MOVE(ArrayInputStream); ~ArrayInputStream() noexcept(false); // implements BufferedInputStream ---------------------------------- @@ -208,7 +208,7 @@ class ArrayInputStream: public BufferedInputStream { class ArrayOutputStream: public BufferedOutputStream { public: explicit ArrayOutputStream(ArrayPtr array); - KJ_DISALLOW_COPY(ArrayOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(ArrayOutputStream); ~ArrayOutputStream() noexcept(false); ArrayPtr getArray() { @@ -228,7 +228,7 @@ class ArrayOutputStream: public BufferedOutputStream { class VectorOutputStream: public BufferedOutputStream { public: explicit VectorOutputStream(size_t initialCapacity = 4096); - KJ_DISALLOW_COPY(VectorOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(VectorOutputStream); ~VectorOutputStream() noexcept(false); ArrayPtr getArray() { @@ -236,6 +236,8 @@ class VectorOutputStream: public BufferedOutputStream { return arrayPtr(vector.begin(), fillPos); } + void clear() { fillPos = vector.begin(); } + // implements BufferedInputStream ---------------------------------- ArrayPtr getWriteBuffer() override; void write(const void* buffer, size_t size) override; @@ -311,7 +313,7 @@ class FdInputStream: public InputStream { public: explicit FdInputStream(int fd): fd(fd) {} explicit FdInputStream(AutoCloseFd fd): fd(fd), autoclose(mv(fd)) {} - KJ_DISALLOW_COPY(FdInputStream); + KJ_DISALLOW_COPY_AND_MOVE(FdInputStream); ~FdInputStream() noexcept(false); size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -329,7 +331,7 @@ class FdOutputStream: public OutputStream { public: explicit FdOutputStream(int fd): fd(fd) {} explicit FdOutputStream(AutoCloseFd fd): fd(fd), autoclose(mv(fd)) {} - KJ_DISALLOW_COPY(FdOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(FdOutputStream); ~FdOutputStream() noexcept(false); void write(const void* buffer, size_t size) override; @@ -405,7 +407,7 @@ class HandleInputStream: public InputStream { public: explicit HandleInputStream(void* handle): handle(handle) {} explicit HandleInputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {} - KJ_DISALLOW_COPY(HandleInputStream); + KJ_DISALLOW_COPY_AND_MOVE(HandleInputStream); ~HandleInputStream() noexcept(false); size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -421,7 +423,7 @@ class HandleOutputStream: public OutputStream { public: explicit HandleOutputStream(void* handle): handle(handle) {} explicit HandleOutputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {} - KJ_DISALLOW_COPY(HandleOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(HandleOutputStream); ~HandleOutputStream() noexcept(false); void write(const void* buffer, size_t size) override; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/list-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/list-test.c++ index 0c7172de8bd..9286226e5ea 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/list-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/list-test.c++ @@ -39,6 +39,7 @@ KJ_TEST("List") { TestElement foo(123); TestElement bar(456); + TestElement baz(789); { list.add(foo); @@ -77,6 +78,26 @@ KJ_TEST("List") { ++iter; KJ_ASSERT(iter == clist.end()); } + + { + list.addFront(baz); + KJ_EXPECT(list.size() == 3); + KJ_DEFER(list.remove(baz)); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter == list.end()); + } + } } KJ_EXPECT(list.size() == 1); @@ -97,7 +118,7 @@ KJ_TEST("List") { KJ_EXPECT(list.size() == 0); { - list.add(bar); + list.addFront(bar); KJ_DEFER(list.remove(bar)); KJ_EXPECT(!list.empty()); KJ_EXPECT(list.size() == 1); @@ -110,6 +131,23 @@ KJ_TEST("List") { ++iter; KJ_ASSERT(iter == list.end()); } + + { + list.add(baz); + KJ_EXPECT(list.size() == 2); + KJ_DEFER(list.remove(baz)); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + KJ_ASSERT(iter == list.end()); + } + } } KJ_EXPECT(list.empty()); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/list.h b/libs/EXTERNAL/capnproto/c++/src/kj/list.h index 02b8cdb39e9..4575b0f96ea 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/list.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/list.h @@ -71,7 +71,7 @@ class List { // // Note that you MUST manually remove an element from the list before destroying it. ListLinks // do not automatically unlink themselves because this could lead to subtle thread-safety bugs - // if the List is guarded by a mutex, and that mutex is not currenty locked. Normally, you should + // if the List is guarded by a mutex, and that mutex is not currently locked. Normally, you should // have T's destructor remove it from any lists. You can use `link.isLinked()` to check if the // item is currently in a list. // @@ -84,7 +84,7 @@ class List { public: List() = default; - KJ_DISALLOW_COPY(List); + KJ_DISALLOW_COPY_AND_MOVE(List); bool empty() const { return head == nullptr; @@ -102,6 +102,19 @@ class List { ++listSize; } + void addFront(T& element) { + if ((element.*link).prev != nullptr) _::throwDoubleAdd(); + (element.*link).next = head; + (element.*link).prev = &head; + KJ_IF_MAYBE(oldHead, head) { + (oldHead->*link).prev = &(element.*link).next; + } else { + tail = &(element.*link).next; + } + head = element; + ++listSize; + } + void remove(T& element) { if ((element.*link).prev == nullptr) _::throwRemovedNotPresent(); *((element.*link).prev) = (element.*link).next; @@ -141,7 +154,7 @@ class ListLink { // Intentionally `noexcept` because we want to crash if a dangling pointer was left in a list. if (prev != nullptr) _::throwDestroyedWhileInList(); } - KJ_DISALLOW_COPY(ListLink); + KJ_DISALLOW_COPY_AND_MOVE(ListLink); bool isLinked() const { return prev != nullptr; } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/map-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/map-test.c++ index ac2b2410e3c..42b5846e600 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/map-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/map-test.c++ @@ -29,7 +29,9 @@ namespace { KJ_TEST("HashMap") { HashMap map; - map.insert(kj::str("foo"), 123); + kj::String ownFoo = kj::str("foo"); + const char* origFoo = ownFoo.begin(); + map.insert(kj::mv(ownFoo), 123); map.insert(kj::str("bar"), 456); KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 123); @@ -39,10 +41,16 @@ KJ_TEST("HashMap") { map.upsert(kj::str("foo"), 789, [](int& old, uint newValue) { KJ_EXPECT(old == 123); KJ_EXPECT(newValue == 789); - old = 321; + old = 4321; }); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 4321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + map.upsert(kj::str("foo"), 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); KJ_EXPECT( map.findOrCreate("foo"_kj, @@ -70,7 +78,9 @@ KJ_TEST("HashMap") { KJ_TEST("TreeMap") { TreeMap map; - map.insert(kj::str("foo"), 123); + kj::String ownFoo = kj::str("foo"); + const char* origFoo = ownFoo.begin(); + map.insert(kj::mv(ownFoo), 123); map.insert(kj::str("bar"), 456); KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 123); @@ -80,10 +90,16 @@ KJ_TEST("TreeMap") { map.upsert(kj::str("foo"), 789, [](int& old, uint newValue) { KJ_EXPECT(old == 123); KJ_EXPECT(newValue == 789); - old = 321; + old = 4321; }); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 4321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + map.upsert(kj::str("foo"), 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); KJ_EXPECT( map.findOrCreate("foo"_kj, diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/map.h b/libs/EXTERNAL/capnproto/c++/src/kj/map.h index bbd2058a01d..4f92a2034e9 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/map.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/map.h @@ -68,8 +68,10 @@ class HashMap { template Entry& upsert(Key key, Value value, UpdateFunc&& update); + Entry& upsert(Key key, Value value); // Tries to insert a new entry. However, if a duplicate already exists (according to some index), // then update(Value& existingValue, Value&& newValue) is called to modify the existing value. + // If no function is provided, the default is to simply replace the value (but not the key). template kj::Maybe find(KeyLike&& key); @@ -97,12 +99,15 @@ class HashMap { bool erase(KeyLike&& key); // Erase the entry with the matching key. // - // WARNING: This invalidates all pointers and interators into the map. Use eraseAll() if you need + // WARNING: This invalidates all pointers and iterators into the map. Use eraseAll() if you need // to iterate and erase multiple entries. void erase(Entry& entry); // Erase an entry by reference. + Entry release(Entry& row); + // Erase an entry and return its content by move. + template ()(instance(), instance()))> size_t eraseAll(Predicate&& predicate); @@ -167,8 +172,10 @@ class TreeMap { template Entry& upsert(Key key, Value value, UpdateFunc&& update); + Entry& upsert(Key key, Value value); // Tries to insert a new entry. However, if a duplicate already exists (according to some index), // then update(Value& existingValue, Value&& newValue) is called to modify the existing value. + // If no function is provided, the default is to simply replace the value (but not the key). template kj::Maybe find(KeyLike&& key); @@ -200,12 +207,15 @@ class TreeMap { bool erase(KeyLike&& key); // Erase the entry with the matching key. // - // WARNING: This invalidates all pointers and interators into the map. Use eraseAll() if you need + // WARNING: This invalidates all pointers and iterators into the map. Use eraseAll() if you need // to iterate and erase multiple entries. void erase(Entry& entry); // Erase an entry by reference. + Entry release(Entry& row); + // Erase an entry and return its content by move. + template ()(instance(), instance()))> size_t eraseAll(Predicate&& predicate); @@ -350,6 +360,15 @@ typename HashMap::Entry& HashMap::upsert( }); } +template +typename HashMap::Entry& HashMap::upsert( + Key key, Value value) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + existingEntry.value = kj::mv(newEntry.value); + }); +} + template template kj::Maybe HashMap::find(KeyLike&& key) { @@ -397,6 +416,11 @@ void HashMap::erase(Entry& entry) { table.erase(entry); } +template +typename HashMap::Entry HashMap::release(Entry& entry) { + return table.release(entry); +} + template template size_t HashMap::eraseAll(Predicate&& predicate) { @@ -463,6 +487,15 @@ typename TreeMap::Entry& TreeMap::upsert( }); } +template +typename TreeMap::Entry& TreeMap::upsert( + Key key, Value value) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + existingEntry.value = kj::mv(newEntry.value); + }); +} + template template kj::Maybe TreeMap::find(KeyLike&& key) { @@ -521,6 +554,11 @@ void TreeMap::erase(Entry& entry) { table.erase(entry); } +template +typename TreeMap::Entry TreeMap::release(Entry& entry) { + return table.release(entry); +} + template template size_t TreeMap::eraseAll(Predicate&& predicate) { diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/memory-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/memory-test.c++ index 6e1e343232b..96ec0f58321 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/memory-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/memory-test.c++ @@ -225,7 +225,7 @@ struct SingularDerivedDynamic final: public DynamicType1 { ~SingularDerivedDynamic() { destructorCalled = true; } - KJ_DISALLOW_COPY(SingularDerivedDynamic); + KJ_DISALLOW_COPY_AND_MOVE(SingularDerivedDynamic); bool& destructorCalled; }; @@ -238,7 +238,7 @@ struct MultipleDerivedDynamic final: public DynamicType1, public DynamicType2 { destructorCalled = true; } - KJ_DISALLOW_COPY(MultipleDerivedDynamic); + KJ_DISALLOW_COPY_AND_MOVE(MultipleDerivedDynamic); bool& destructorCalled; }; @@ -295,6 +295,14 @@ TEST(Memory, OwnVoid) { voidPtr = nullptr; KJ_EXPECT(destructorCalled); } + + { + Maybe> maybe; + maybe = Own(&maybe, NullDisposer::instance); + KJ_EXPECT(KJ_ASSERT_NONNULL(maybe).get() == &maybe); + maybe = nullptr; + KJ_EXPECT(maybe == nullptr); + } } TEST(Memory, OwnConstVoid) { @@ -349,6 +357,14 @@ TEST(Memory, OwnConstVoid) { voidPtr = nullptr; KJ_EXPECT(destructorCalled); } + + { + Maybe> maybe; + maybe = Own(&maybe, NullDisposer::instance); + KJ_EXPECT(KJ_ASSERT_NONNULL(maybe).get() == &maybe); + maybe = nullptr; + KJ_EXPECT(maybe == nullptr); + } } struct IncompleteType; @@ -395,6 +411,95 @@ KJ_TEST("Own") { } } +KJ_TEST("Own with static disposer") { + static int* disposedPtr = nullptr; + struct MyDisposer { + static void dispose(int* value) { + KJ_EXPECT(disposedPtr == nullptr); + disposedPtr = value; + }; + }; + + int i; + + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); + disposedPtr = nullptr; + + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + Own ptr2(kj::mv(ptr)); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); + disposedPtr = nullptr; + + { + Own ptr2; + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + ptr2 = kj::mv(ptr); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); +} + +KJ_TEST("Maybe>") { + Maybe> m = heap(123); + KJ_EXPECT(m != nullptr); + Maybe mRef = m; + KJ_EXPECT(KJ_ASSERT_NONNULL(mRef) == 123); + KJ_EXPECT(&KJ_ASSERT_NONNULL(mRef) == KJ_ASSERT_NONNULL(m).get()); +} + +#if KJ_CPP_STD > 201402L +int* sawIntPtr = nullptr; + +void freeInt(int* ptr) { + sawIntPtr = ptr; + delete ptr; +} + +void freeChar(char* c) { + delete c; +} + +void free(StaticType* ptr) { + delete ptr; +} + +void free(const char* ptr) {} + +KJ_TEST("disposeWith") { + auto i = new int(1); + { + auto p = disposeWith(i); + KJ_EXPECT(sawIntPtr == nullptr); + } + KJ_EXPECT(sawIntPtr == i); + { + auto c = new char('a'); + auto p = disposeWith(c); + } + { + // Explicit cast required to avoid ambiguity when overloads are present. + auto s = new StaticType{1}; + auto p = disposeWith(free)>(s); + } + { + const char c = 'a'; + auto p2 = disposeWith(free)>(&c); + } +} +#endif + // TODO(test): More tests. } // namespace diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/memory.h b/libs/EXTERNAL/capnproto/c++/src/kj/memory.h index 1229b5c3ecf..6b004f988c7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/memory.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/memory.h @@ -171,8 +171,11 @@ class NullDisposer: public Disposer { // ======================================================================================= // Own -- An owned pointer. +template +class Own; + template -class Own { +class Own { // A transferrable title to a T. When an Own goes out of scope, the object's Disposer is // called to dispose of it. An Own can be efficiently passed by move, without relocating the // underlying object; this transfers ownership. @@ -199,6 +202,9 @@ class Own { : disposer(other.disposer), ptr(cast(other.ptr)) { other.ptr = nullptr; } + template ()>> + inline Own(Own&& other) noexcept; + // Convert statically-disposed Own to dynamically-disposed Own. inline Own(T* ptr, const Disposer& disposer) noexcept: disposer(&disposer), ptr(ptr) {} ~Own() noexcept(false) { dispose(); } @@ -286,7 +292,7 @@ class Own { return ptr; } - template + template friend class Own; friend class Maybe>; }; @@ -303,72 +309,180 @@ inline const void* Own::cast(U* ptr) { return _::castToConstVoid(ptr); } +template +class Own { + // If a `StaticDisposer` is specified (which is not the norm), then the object will be deleted + // by calling StaticDisposer::dispose(pointer). The pointer passed to `dispose()` could be a + // superclass of `T`, if the pointer has been upcast. + // + // This type can be useful for micro-optimization, if you've found that you are doing excessive + // heap allocations to the point where the virtual call on destruction is costing non-negligible + // resources. You should avoid this unless you have a specific need, because it precludes a lot + // of power. + +public: + KJ_DISALLOW_COPY(Own); + inline Own(): ptr(nullptr) {} + inline Own(Own&& other) noexcept + : ptr(other.ptr) { other.ptr = nullptr; } + inline Own(Own, StaticDisposer>&& other) noexcept + : ptr(other.ptr) { other.ptr = nullptr; } + template ()>> + inline Own(Own&& other) noexcept + : ptr(cast(other.ptr)) { + other.ptr = nullptr; + } + inline explicit Own(T* ptr) noexcept: ptr(ptr) {} + + ~Own() noexcept(false) { dispose(); } + + inline Own& operator=(Own&& other) { + // Move-assignnment operator. + + // Careful, this might own `other`. Therefore we have to transfer the pointers first, then + // dispose. + T* ptrCopy = ptr; + ptr = other.ptr; + other.ptr = nullptr; + if (ptrCopy != nullptr) { + StaticDisposer::dispose(ptrCopy); + } + return *this; + } + + inline Own& operator=(decltype(nullptr)) { + dispose(); + return *this; + } + + template + Own downcast() { + // Downcast the pointer to Own, destroying the original pointer. If this pointer does not + // actually point at an instance of U, the results are undefined (throws an exception in debug + // mode if RTTI is enabled, otherwise you're on your own). + + Own result; + if (ptr != nullptr) { + result.ptr = &kj::downcast(*ptr); + ptr = nullptr; + } + return result; + } + +#define NULLCHECK KJ_IREQUIRE(ptr != nullptr, "null Own<> dereference") + inline T* operator->() { NULLCHECK; return ptr; } + inline const T* operator->() const { NULLCHECK; return ptr; } + inline _::RefOrVoid operator*() { NULLCHECK; return *ptr; } + inline _::RefOrVoid operator*() const { NULLCHECK; return *ptr; } +#undef NULLCHECK + inline T* get() { return ptr; } + inline const T* get() const { return ptr; } + inline operator T*() { return ptr; } + inline operator const T*() const { return ptr; } + +private: + T* ptr; + + inline explicit Own(decltype(nullptr)): ptr(nullptr) {} + + inline bool operator==(decltype(nullptr)) { return ptr == nullptr; } + inline bool operator!=(decltype(nullptr)) { return ptr != nullptr; } + // Only called by Maybe>. + + inline void dispose() { + // Make sure that if an exception is thrown, we are left with a null ptr, so we won't possibly + // dispose again. + T* ptrCopy = ptr; + if (ptrCopy != nullptr) { + ptr = nullptr; + StaticDisposer::dispose(ptrCopy); + } + } + + template + static inline T* cast(U* ptr) { + return ptr; + } + + template + friend class Own; + friend class Maybe>; +}; + namespace _ { // private -template +template class OwnOwn { public: - inline OwnOwn(Own&& value) noexcept: value(kj::mv(value)) {} + inline OwnOwn(Own&& value) noexcept: value(kj::mv(value)) {} - inline Own& operator*() & { return value; } - inline const Own& operator*() const & { return value; } - inline Own&& operator*() && { return kj::mv(value); } - inline const Own&& operator*() const && { return kj::mv(value); } - inline Own* operator->() { return &value; } - inline const Own* operator->() const { return &value; } - inline operator Own*() { return value ? &value : nullptr; } - inline operator const Own*() const { return value ? &value : nullptr; } + inline Own& operator*() & { return value; } + inline const Own& operator*() const & { return value; } + inline Own&& operator*() && { return kj::mv(value); } + inline const Own&& operator*() const && { return kj::mv(value); } + inline Own* operator->() { return &value; } + inline const Own* operator->() const { return &value; } + inline operator Own*() { return value ? &value : nullptr; } + inline operator const Own*() const { return value ? &value : nullptr; } private: - Own value; + Own value; }; -template -OwnOwn readMaybe(Maybe>&& maybe) { return OwnOwn(kj::mv(maybe.ptr)); } -template -Own* readMaybe(Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } -template -const Own* readMaybe(const Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } +template +OwnOwn readMaybe(Maybe>&& maybe) { return OwnOwn(kj::mv(maybe.ptr)); } +template +Own* readMaybe(Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } +template +const Own* readMaybe(const Maybe>& maybe) { + return maybe.ptr ? &maybe.ptr : nullptr; +} } // namespace _ (private) -template -class Maybe> { +template +class Maybe> { public: inline Maybe(): ptr(nullptr) {} - inline Maybe(Own&& t) noexcept: ptr(kj::mv(t)) {} + inline Maybe(Own&& t) noexcept: ptr(kj::mv(t)) {} inline Maybe(Maybe&& other) noexcept: ptr(kj::mv(other.ptr)) {} template - inline Maybe(Maybe>&& other): ptr(mv(other.ptr)) {} + inline Maybe(Maybe>&& other): ptr(mv(other.ptr)) {} template - inline Maybe(Own&& other): ptr(mv(other)) {} + inline Maybe(Own&& other): ptr(mv(other)) {} inline Maybe(decltype(nullptr)) noexcept: ptr(nullptr) {} - inline Own& emplace(Own value) { + inline Own& emplace(Own value) { // Assign the Maybe to the given value and return the content. This avoids the need to do a // KJ_ASSERT_NONNULL() immediately after setting the Maybe just to read it back again. ptr = kj::mv(value); return ptr; } - inline operator Maybe() { return ptr.get(); } - inline operator Maybe() const { return ptr.get(); } + template + inline operator NoInfer>() { return ptr.get(); } + template + inline operator NoInfer>() const { return ptr.get(); } + // Implicit conversion to `Maybe`. The weird templating is to make sure that + // `Maybe>` can be instantiated with the compiler complaining about forming references + // to void -- the use of templates here will cause SFINAE to kick in and hide these, whereas if + // they are not templates then SFINAE isn't applied and so they are considered errors. inline Maybe& operator=(Maybe&& other) { ptr = kj::mv(other.ptr); return *this; } inline bool operator==(decltype(nullptr)) const { return ptr == nullptr; } inline bool operator!=(decltype(nullptr)) const { return ptr != nullptr; } - Own& orDefault(Own& defaultValue) { + Own& orDefault(Own& defaultValue) { if (ptr == nullptr) { return defaultValue; } else { return ptr; } } - const Own& orDefault(const Own& defaultValue) const { + const Own& orDefault(const Own& defaultValue) const { if (ptr == nullptr) { return defaultValue; } else { @@ -376,8 +490,18 @@ class Maybe> { } } + template () ? instance>() : instance()())> + Result orDefault(F&& lazyDefaultValue) && { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return kj::mv(ptr); + } + } + template - auto map(Func&& f) & -> Maybe&>()))> { + auto map(Func&& f) & -> Maybe&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -386,7 +510,7 @@ class Maybe> { } template - auto map(Func&& f) const & -> Maybe&>()))> { + auto map(Func&& f) const & -> Maybe&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -395,7 +519,7 @@ class Maybe> { } template - auto map(Func&& f) && -> Maybe&&>()))> { + auto map(Func&& f) && -> Maybe&&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -404,7 +528,7 @@ class Maybe> { } template - auto map(Func&& f) const && -> Maybe&&>()))> { + auto map(Func&& f) const && -> Maybe&&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -413,16 +537,16 @@ class Maybe> { } private: - Own ptr; + Own ptr; template friend class Maybe; - template - friend _::OwnOwn _::readMaybe(Maybe>&& maybe); - template - friend Own* _::readMaybe(Maybe>& maybe); - template - friend const Own* _::readMaybe(const Maybe>& maybe); + template + friend _::OwnOwn _::readMaybe(Maybe>&& maybe); + template + friend Own* _::readMaybe(Maybe>& maybe); + template + friend const Own* _::readMaybe(const Maybe>& maybe); }; namespace _ { // private @@ -447,6 +571,32 @@ template const HeapDisposer HeapDisposer::instance = HeapDisposer(); #endif +#if KJ_CPP_STD >= 202002L +template +class CustomDisposer: public Disposer { +public: + void disposeImpl(void* pointer) const override { + (*F)(reinterpret_cast(pointer)); + } +}; + +template +static constexpr CustomDisposer CUSTOM_DISPOSER_INSTANCE {}; +#else +template +class CustomDisposer: public Disposer { +public: + static const CustomDisposer instance; + + void disposeImpl(void* pointer) const override { + (*F)(reinterpret_cast(pointer)); + } +}; + +template +const CustomDisposer CustomDisposer::instance = CustomDisposer(); +#endif + } // namespace _ (private) template @@ -470,6 +620,26 @@ Own> heap(T&& orig) { return Own(new T2(kj::fwd(orig)), _::HeapDisposer::instance); } +#if KJ_CPP_STD > 201402L +#if KJ_CPP_STD < 202002L +template +Own disposeWith(T* ptr) { + // Associate a pre-allocated raw pointer with a corresponding disposal function. + // The first template parameter should be a function pointer e.g. disposeWith(new int(0)). + + return Own(ptr, _::CustomDisposer::instance); +} +#else +template +Own disposeWith(T* ptr) { + // Associate a pre-allocated raw pointer with a corresponding disposal function. + // The first template parameter should be a function pointer e.g. disposeWith(new int(0)). + + return Own(ptr, _::CUSTOM_DISPOSER_INSTANCE); +} +#endif +#endif + template Own> attachVal(T&& value, Attachments&&... attachments); // Returns an Own that takes ownership of `value` and `attachments`, and points to `value`. @@ -559,6 +729,21 @@ struct DisposableOwnedBundle final: public Disposer, public OwnedBundle { void disposeImpl(void* pointer) const override { delete this; } }; +template +class StaticDisposerAdapter final: public Disposer { + // Adapts a static disposer to be called dynamically. +public: + virtual void disposeImpl(void* pointer) const override { + StaticDisposer::dispose(reinterpret_cast(pointer)); + } + + static const StaticDisposerAdapter instance; +}; + +template +const StaticDisposerAdapter StaticDisposerAdapter::instance = + StaticDisposerAdapter(); + } // namespace _ (private) template @@ -591,6 +776,22 @@ Own> attachVal(T&& value, Attachments&&... attachments) { return Own>(&bundle->first, *bundle); } +template +template +inline Own::Own(Own&& other) noexcept + : ptr(cast(other.ptr)) { + if (_::castToVoid(other.ptr) != reinterpret_cast(other.ptr)) { + // Oh dangit, there's some sort of multiple inheritance going on and `StaticDisposerAdapter` + // won't actually work because it'll receive a pointer pointing to the top of the object, which + // isn't exactly the same as the `U*` pointer it wants. We have no choice but to allocate + // a dynamic disposer here. + disposer = new _::DisposableOwnedBundle>(kj::mv(other)); + } else { + disposer = &_::StaticDisposerAdapter::instance; + other.ptr = nullptr; + } +} + } // namespace kj KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/mutex-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/mutex-test.c++ index 32c0a5cf615..1b51e3e7c0d 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/mutex-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/mutex-test.c++ @@ -686,6 +686,7 @@ KJ_TEST("tracking blocking on mutex acquisition") { event.sigev_value.sival_ptr = &blockingInfo; KJ_SYSCALL(event._sigev_un._tid = gettid()); KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); kj::Duration timeout = 50 * MILLISECONDS; struct itimerspec spec; @@ -742,6 +743,7 @@ KJ_TEST("tracking blocked on CondVar::wait") { event.sigev_value.sival_ptr = &blockingInfo; KJ_SYSCALL(event._sigev_un._tid = gettid()); KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); kj::Duration timeout = 50 * MILLISECONDS; struct itimerspec spec; @@ -798,6 +800,7 @@ KJ_TEST("tracking blocked on Once::init") { event.sigev_value.sival_ptr = &blockingInfo; KJ_SYSCALL(event._sigev_un._tid = gettid()); KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); Lazy once; MutexGuarded onceInitializing(false); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/mutex.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/mutex.c++ index 63e4f86793a..1ddd4921be4 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/mutex.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/mutex.c++ @@ -144,7 +144,7 @@ bool Mutex::checkPredicate(Waiter& waiter) { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { result = waiter.predicate.check(); })) { - // Exception thown. + // Exception thrown. result = true; waiter.exception = kj::heap(kj::mv(*exception)); }; @@ -499,7 +499,7 @@ void Mutex::wait(Predicate& predicate, Maybe timeout, LockSourceLocati KJ_SYSCALL_HANDLE_ERRORS(syscall(SYS_futex, &waiter.futex, FUTEX_WAIT_BITSET_PRIVATE, 0, tsp, nullptr, FUTEX_BITSET_MATCH_ANY)) { case EAGAIN: - // Indicates that the futex was already non-zero by the time the kernal looked at it. + // Indicates that the futex was already non-zero by the time the kernel looked at it. // Not an error. break; case ETIMEDOUT: { diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/mutex.h b/libs/EXTERNAL/capnproto/c++/src/kj/mutex.h index e7b299e8116..619e7f95d45 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/mutex.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/mutex.h @@ -129,7 +129,7 @@ class Mutex { public: Mutex(); ~Mutex(); - KJ_DISALLOW_COPY(Mutex); + KJ_DISALLOW_COPY_AND_MOVE(Mutex); enum Exclusivity { EXCLUSIVE, @@ -247,7 +247,7 @@ class Mutex { kj::Maybe waitersHead = nullptr; kj::Maybe* waitersTail = &waitersHead; - // linked list of waitUntil()s; can only modify under lock + // linked list of waiters; can only modify under lock inline void addWaiter(Waiter& waiter); inline void removeWaiter(Waiter& waiter); @@ -268,7 +268,7 @@ class Once { Once(bool startInitialized = false); ~Once(); #endif - KJ_DISALLOW_COPY(Once); + KJ_DISALLOW_COPY_AND_MOVE(Once); class Initializer { public: @@ -446,12 +446,12 @@ class MutexGuarded { Maybe> lockExclusiveWithTimeout(Duration timeout, LockSourceLocationArg location = {}) const; - // Attempts to exclusively lock the object. If the timeout elapses before the lock is aquired, + // Attempts to exclusively lock the object. If the timeout elapses before the lock is acquired, // this returns null. Maybe> lockSharedWithTimeout(Duration timeout, LockSourceLocationArg location = {}) const; - // Attempts to lock the value for shared access. If the timeout elapses before the lock is aquired, + // Attempts to lock the value for shared access. If the timeout elapses before the lock is acquired, // this returns null. inline const T& getWithoutLock() const { return value; } @@ -758,7 +758,7 @@ using BlockedOnReason = OneOf blockedReason() noexcept; // Returns the information about the reason the current thread is blocked synchronously on KJ // lock primitives. Returns nullptr if the current thread is not currently blocked on such -// primitves. This is intended to be called from a signal handler to check whether the current +// primitives. This is intended to be called from a signal handler to check whether the current // thread is blocked. Outside of a signal handler there is little value to this function. In those // cases by definition the thread is not blocked. This includes the callable used as part of a // condition variable since that happens after the lock is acquired & the current thread is no diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/one-of-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/one-of-test.c++ index 7c74ca7e858..d7bec4d1084 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/one-of-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/one-of-test.c++ @@ -210,4 +210,25 @@ KJ_TEST("OneOf copy/move from alternative variants") { } } +template +struct T { + unsigned int n = N; +}; + +TEST(OneOf, MaxVariants) { + kj::OneOf< + T<1>, T<2>, T<3>, T<4>, T<5>, T<6>, T<7>, T<8>, T<9>, T<10>, + T<11>, T<12>, T<13>, T<14>, T<15>, T<16>, T<17>, T<18>, T<19>, T<20>, + T<21>, T<22>, T<23>, T<24>, T<25>, T<26>, T<27>, T<28>, T<29>, T<30>, + T<31>, T<32>, T<33>, T<34>, T<35>, T<36>, T<37>, T<38>, T<39>, T<40>, + T<41>, T<42>, T<43>, T<44>, T<45>, T<46>, T<47>, T<48>, T<49>, T<50> + > v; + + v = T<1>(); + EXPECT_TRUE(v.is>()); + + v = T<50>(); + EXPECT_TRUE(v.is>()); +} + } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/one-of.h b/libs/EXTERNAL/capnproto/c++/src/kj/one-of.h index cbed3916b87..51d4220a457 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/one-of.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/one-of.h @@ -98,6 +98,191 @@ enum class Variants20 { _variant0, _variant1, _variant2, _variant3, _variant4, _ _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, _variant19 }; +enum class Variants21 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20 }; +enum class Variants22 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21 }; +enum class Variants23 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22 }; +enum class Variants24 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23 }; +enum class Variants25 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24 }; +enum class Variants26 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25 }; +enum class Variants27 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26 }; +enum class Variants28 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27 }; +enum class Variants29 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28 }; +enum class Variants30 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29 }; +enum class Variants31 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30 }; +enum class Variants32 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31 }; +enum class Variants33 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32 }; +enum class Variants34 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33 }; +enum class Variants35 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34 }; +enum class Variants36 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35 }; +enum class Variants37 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36 }; +enum class Variants38 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37 }; +enum class Variants39 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38 }; +enum class Variants40 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39 }; +enum class Variants41 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40 }; +enum class Variants42 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41 }; +enum class Variants43 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42 }; +enum class Variants44 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43 }; +enum class Variants45 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44 }; +enum class Variants46 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45 }; +enum class Variants47 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46 }; +enum class Variants48 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47 }; +enum class Variants49 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47, _variant48 }; +enum class Variants50 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47, _variant48, + _variant49 }; template struct Variants_; template <> struct Variants_<0> { typedef Variants0 Type; }; @@ -121,6 +306,36 @@ template <> struct Variants_<17> { typedef Variants17 Type; }; template <> struct Variants_<18> { typedef Variants18 Type; }; template <> struct Variants_<19> { typedef Variants19 Type; }; template <> struct Variants_<20> { typedef Variants20 Type; }; +template <> struct Variants_<21> { typedef Variants21 Type; }; +template <> struct Variants_<22> { typedef Variants22 Type; }; +template <> struct Variants_<23> { typedef Variants23 Type; }; +template <> struct Variants_<24> { typedef Variants24 Type; }; +template <> struct Variants_<25> { typedef Variants25 Type; }; +template <> struct Variants_<26> { typedef Variants26 Type; }; +template <> struct Variants_<27> { typedef Variants27 Type; }; +template <> struct Variants_<28> { typedef Variants28 Type; }; +template <> struct Variants_<29> { typedef Variants29 Type; }; +template <> struct Variants_<30> { typedef Variants30 Type; }; +template <> struct Variants_<31> { typedef Variants31 Type; }; +template <> struct Variants_<32> { typedef Variants32 Type; }; +template <> struct Variants_<33> { typedef Variants33 Type; }; +template <> struct Variants_<34> { typedef Variants34 Type; }; +template <> struct Variants_<35> { typedef Variants35 Type; }; +template <> struct Variants_<36> { typedef Variants36 Type; }; +template <> struct Variants_<37> { typedef Variants37 Type; }; +template <> struct Variants_<38> { typedef Variants38 Type; }; +template <> struct Variants_<39> { typedef Variants39 Type; }; +template <> struct Variants_<40> { typedef Variants40 Type; }; +template <> struct Variants_<41> { typedef Variants41 Type; }; +template <> struct Variants_<42> { typedef Variants42 Type; }; +template <> struct Variants_<43> { typedef Variants43 Type; }; +template <> struct Variants_<44> { typedef Variants44 Type; }; +template <> struct Variants_<45> { typedef Variants45 Type; }; +template <> struct Variants_<46> { typedef Variants46 Type; }; +template <> struct Variants_<47> { typedef Variants47 Type; }; +template <> struct Variants_<48> { typedef Variants48 Type; }; +template <> struct Variants_<49> { typedef Variants49 Type; }; +template <> struct Variants_<50> { typedef Variants50 Type; }; template using Variants = typename Variants_::Type; @@ -389,7 +604,7 @@ void OneOf::allHandled() { KJ_UNREACHABLE; } -#if __cplusplus > 201402L +#if KJ_CPP_STD > 201402L #define KJ_SWITCH_ONEOF(value) \ switch (auto _kj_switch_subject = (value)._switchSubject(); _kj_switch_subject->which()) #else diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/parse/common.h b/libs/EXTERNAL/capnproto/c++/src/kj/parse/common.h index 6d2653f907d..cfb97299caf 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/parse/common.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/parse/common.h @@ -68,7 +68,7 @@ class IteratorInput { parent->best = kj::max(kj::max(pos, best), parent->best); } } - KJ_DISALLOW_COPY(IteratorInput); + KJ_DISALLOW_COPY_AND_MOVE(IteratorInput); void advanceParent() { parent->pos = pos; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/refcount-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/refcount-test.c++ index 81f5e5fb374..c3b39d45770 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/refcount-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/refcount-test.c++ @@ -57,4 +57,81 @@ TEST(Refcount, Basic) { #endif } +struct SetTrueInDestructor2 { + // Like above but doesn't inherit Refcounted. + + SetTrueInDestructor2(bool* ptr): ptr(ptr) {} + ~SetTrueInDestructor2() { *ptr = true; } + + bool* ptr; +}; + +KJ_TEST("RefcountedWrapper") { + { + bool b = false; + Own> w = refcountedWrapper(&b); + KJ_EXPECT(!b); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == &w->getWrapped()); + KJ_EXPECT(ref1.get() == ref2.get()); + + KJ_EXPECT(!b); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(!b); + + ref2 = nullptr; + + KJ_EXPECT(b); + } + + // Wrap Own. + { + bool b = false; + Own>> w = + refcountedWrapper(kj::heap(&b)); + KJ_EXPECT(!b); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == &w->getWrapped()); + KJ_EXPECT(ref1.get() == ref2.get()); + + KJ_EXPECT(!b); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(!b); + + ref2 = nullptr; + + KJ_EXPECT(b); + } + + // Try wrapping an `int` to really demonstrate the wrapped type can be anything. + { + Own> w = refcountedWrapper(123); + int* ptr = &w->getWrapped(); + KJ_EXPECT(*ptr == 123); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == ptr); + KJ_EXPECT(ref2.get() == ptr); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(*ref2 == 123); + } +} + } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/refcount.h b/libs/EXTERNAL/capnproto/c++/src/kj/refcount.h index 51fd6dc79b8..03b5234d8d7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/refcount.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/refcount.h @@ -67,7 +67,7 @@ class Refcounted: private Disposer { public: Refcounted() = default; virtual ~Refcounted() noexcept(false); - KJ_DISALLOW_COPY(Refcounted); + KJ_DISALLOW_COPY_AND_MOVE(Refcounted); inline bool isShared() const { return refcount > 1; } // Check if there are multiple references to this object. This is sometimes useful for deciding @@ -85,6 +85,9 @@ class Refcounted: private Disposer { friend Own addRef(T& object); template friend Own refcounted(Params&&... params); + + template + friend class RefcountedWrapper; }; template @@ -112,6 +115,59 @@ Own Refcounted::addRefInternal(T* object) { return Own(object, *refcounted); } +template +class RefcountedWrapper: public Refcounted { + // Adds refcounting as a wrapper around an existing type, allowing you to construct references + // with type Own that appears to point directly to the underlying object. + +public: + template + RefcountedWrapper(Params&&... params): wrapped(kj::fwd(params)...) {} + + T& getWrapped() { return wrapped; } + const T& getWrapped() const { return wrapped; } + + Own addWrappedRef() { + // Return an owned reference to the wrapped object that is backed by a refcount. + ++refcount; + return Own(&wrapped, *this); + } + +private: + T wrapped; +}; + +template +class RefcountedWrapper>: public Refcounted { + // Specialization for when the wrapped type is itself Own. We don't want this to result in + // Own>. + +public: + RefcountedWrapper(Own wrapped): wrapped(kj::mv(wrapped)) {} + + T& getWrapped() { return *wrapped; } + const T& getWrapped() const { return *wrapped; } + + Own addWrappedRef() { + // Return an owned reference to the wrapped object that is backed by a refcount. + ++refcount; + return Own(wrapped.get(), *this); + } + +private: + Own wrapped; +}; + +template +Own> refcountedWrapper(Params&&... params) { + return refcounted>(kj::fwd(params)...); +} + +template +Own>> refcountedWrapper(Own&& wrapped) { + return refcounted>>(kj::mv(wrapped)); +} + // ======================================================================================= // Atomic (thread-safe) refcounting // @@ -129,7 +185,7 @@ class AtomicRefcounted: private kj::Disposer { public: AtomicRefcounted() = default; virtual ~AtomicRefcounted() noexcept(false); - KJ_DISALLOW_COPY(AtomicRefcounted); + KJ_DISALLOW_COPY_AND_MOVE(AtomicRefcounted); inline bool isShared() const { #if _MSC_VER && !defined(__clang__) diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/source-location.h b/libs/EXTERNAL/capnproto/c++/src/kj/source-location.h index ebcd4d3f120..40c53e8b2e0 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/source-location.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/source-location.h @@ -23,6 +23,8 @@ #include "string.h" +KJ_BEGIN_HEADER + // GCC does not implement __builtin_COLUMN() as that's non-standard but MSVC & clang do. // MSVC does as of version https://github.com/microsoft/STL/issues/54) but there's currently not any // pressing need for this for MSVC & writing the write compiler version check is annoying. @@ -39,7 +41,7 @@ #define KJ_CALLER_COLUMN() 0 #endif -#if __cplusplus > 201703L +#if KJ_CPP_STD > 201703L #define KJ_COMPILER_SUPPORTS_SOURCE_LOCATION 1 #elif defined(__has_builtin) // Clang 9 added these builtins: https://releases.llvm.org/9.0.0/tools/clang/docs/LanguageExtensions.html @@ -105,3 +107,5 @@ KJ_UNUSED static kj::String KJ_STRINGIFY(const NoopSourceLocation& l) { return kj::String(); } } // namespace kj + +KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/string-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/string-test.c++ index b461a4bb3fc..6109b7bbabf 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/string-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/string-test.c++ @@ -24,6 +24,7 @@ #include #include "vector.h" #include +#include namespace kj { namespace _ { // private @@ -164,6 +165,91 @@ TEST(String, parseAs) { EXPECT_EQ(heapString("1").parseAs(), 1); } +TEST(String, tryParseAs) { + KJ_EXPECT(StringPtr("0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("0.0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1.0); + KJ_EXPECT(StringPtr("1.0").tryParseAs() == 1.0); + KJ_EXPECT(StringPtr("1e100").tryParseAs() == 1e100); + KJ_EXPECT(StringPtr("inf").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("infinity").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("INF").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("INFINITY").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("1e100000").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("-inf").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-infinity").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-INF").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-INFINITY").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-1e100000").tryParseAs() == -inf()); + KJ_EXPECT(isNaN(StringPtr("nan").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(isNaN(StringPtr("NAN").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(isNaN(StringPtr("NaN").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1.0); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("9223372036854775807").tryParseAs() == 9223372036854775807LL); + KJ_EXPECT(StringPtr("-9223372036854775808").tryParseAs() == -9223372036854775808ULL); + KJ_EXPECT(StringPtr("9223372036854775808").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("-9223372036854775809").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("010").tryParseAs() == 10); + KJ_EXPECT(StringPtr("0010").tryParseAs() == 10); + KJ_EXPECT(StringPtr("0x10").tryParseAs() == 16); + KJ_EXPECT(StringPtr("0X10").tryParseAs() == 16); + KJ_EXPECT(StringPtr("-010").tryParseAs() == -10); + KJ_EXPECT(StringPtr("-0x10").tryParseAs() == -16); + KJ_EXPECT(StringPtr("-0X10").tryParseAs() == -16); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0); + KJ_EXPECT(StringPtr("18446744073709551615").tryParseAs() == 18446744073709551615ULL); + KJ_EXPECT(StringPtr("-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("18446744073709551616").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("2147483647").tryParseAs() == 2147483647); + KJ_EXPECT(StringPtr("-2147483648").tryParseAs() == -2147483648); + KJ_EXPECT(StringPtr("2147483648").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("-2147483649").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0U); + KJ_EXPECT(StringPtr("4294967295").tryParseAs() == 4294967295U); + KJ_EXPECT(StringPtr("-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("4294967296").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + + KJ_EXPECT(heapString("1").tryParseAs() == 1); +} + #if KJ_COMPILER_SUPPORTS_STL_STRING_INTEROP TEST(String, StlInterop) { std::string foo = "foo"; @@ -246,7 +332,7 @@ KJ_TEST("ArrayPtr == StringPtr") { ArrayPtr a = s; KJ_EXPECT(a == s); -#if __cplusplus >= 202000L +#if KJ_CPP_STD >= 202000L KJ_EXPECT(s == a); #endif } @@ -282,7 +368,8 @@ KJ_TEST("float stringification and parsing is not locale-dependent") { KJ_EXPECT("1.5"_kj.parseAs() == 1.5); if (setlocale(LC_NUMERIC, "es_ES") == nullptr && - setlocale(LC_NUMERIC, "es_ES.utf8") == nullptr) { + setlocale(LC_NUMERIC, "es_ES.utf8") == nullptr && + setlocale(LC_NUMERIC, "es_ES.UTF-8") == nullptr) { // Some systems may not have the desired locale available. KJ_LOG(WARNING, "Couldn't set locale to es_ES. Skipping this test."); } else { @@ -292,6 +379,60 @@ KJ_TEST("float stringification and parsing is not locale-dependent") { KJ_EXPECT("1.5"_kj.parseAs() == 1.5); } } + +KJ_TEST("ConstString literal operator") { + kj::ConstString theString = "it's a const string!"_kjc; + KJ_EXPECT(theString == "it's a const string!"); +} + +KJ_TEST("ConstString promotion") { + kj::StringPtr theString = "it's a const string!"; + kj::ConstString constString = theString.attach(); + KJ_EXPECT(constString == "it's a const string!"); +} + +struct DestructionOrderRecorder { + DestructionOrderRecorder(uint& counter, uint& recordTo) + : counter(counter), recordTo(recordTo) {} + ~DestructionOrderRecorder() { + recordTo = ++counter; + } + + uint& counter; + uint& recordTo; +}; + +KJ_TEST("ConstString attachment lifetimes") { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + StringPtr theString = "it's a string!"; + const char* ptr = theString.begin(); + + ConstString combined = theString.attach(kj::mv(obj1), kj::mv(obj2), kj::mv(obj3)); + + KJ_EXPECT(combined.begin() == ptr); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + } // namespace } // namespace _ (private) } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/string.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/string.c++ index 7dc5fc67ab4..cf2c5fcaa95 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/string.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/string.c++ @@ -29,11 +29,6 @@ namespace kj { -#if _MSC_VER && !defined(__clang__) -#pragma warning(disable: 4996) -// Warns that sprintf() is buffer-overrunny. We know that, it's cool. -#endif - namespace { bool isHex(const char *s) { if (*s == '-') s++; @@ -51,6 +46,17 @@ long long parseSigned(const StringPtr& s, long long min, long long max) { return value; } +Maybe tryParseSigned(const StringPtr& s, long long min, long long max) { + if (s == nullptr) { return nullptr; } // String does not contain valid number. + char *endPtr; + errno = 0; + auto value = strtoll(s.begin(), &endPtr, isHex(s.cStr()) ? 16 : 10); + if (endPtr != s.end() || errno == ERANGE || value < min || max < value) { + return nullptr; + } + return value; +} + unsigned long long parseUnsigned(const StringPtr& s, unsigned long long max) { KJ_REQUIRE(s != nullptr, "String does not contain valid number", s) { return 0; } char *endPtr; @@ -64,6 +70,15 @@ unsigned long long parseUnsigned(const StringPtr& s, unsigned long long max) { return value; } +Maybe tryParseUnsigned(const StringPtr& s, unsigned long long max) { + if (s == nullptr) { return nullptr; } // String does not contain valid number. + char *endPtr; + errno = 0; + auto value = strtoull(s.begin(), &endPtr, isHex(s.cStr()) ? 16 : 10); + if (endPtr != s.end() || errno == ERANGE || max < value || s[0] == '-') { return nullptr; } + return value; +} + template T parseInteger(const StringPtr& s) { if (static_cast(minValue) < 0) { @@ -76,6 +91,18 @@ T parseInteger(const StringPtr& s) { } } +template +Maybe tryParseInteger(const StringPtr& s) { + if (static_cast(minValue) < 0) { + long long min = static_cast(minValue); + long long max = static_cast(maxValue); + return static_cast>(tryParseSigned(s, min, max)); + } else { + unsigned long long max = static_cast(maxValue); + return static_cast>(tryParseUnsigned(s, max)); + } +} + } // namespace #define PARSE_AS_INTEGER(T) \ @@ -93,6 +120,21 @@ PARSE_AS_INTEGER(long long); PARSE_AS_INTEGER(unsigned long long); #undef PARSE_AS_INTEGER +#define TRY_PARSE_AS_INTEGER(T) \ + template <> Maybe StringPtr::tryParseAs() const { return tryParseInteger(*this); } +TRY_PARSE_AS_INTEGER(char); +TRY_PARSE_AS_INTEGER(signed char); +TRY_PARSE_AS_INTEGER(unsigned char); +TRY_PARSE_AS_INTEGER(short); +TRY_PARSE_AS_INTEGER(unsigned short); +TRY_PARSE_AS_INTEGER(int); +TRY_PARSE_AS_INTEGER(unsigned int); +TRY_PARSE_AS_INTEGER(long); +TRY_PARSE_AS_INTEGER(unsigned long); +TRY_PARSE_AS_INTEGER(long long); +TRY_PARSE_AS_INTEGER(unsigned long long); +#undef TRY_PARSE_AS_INTEGER + String heapString(size_t size) { char* buffer = _::HeapArrayDisposer::allocate(size + 1); buffer[size] = '\0'; @@ -478,7 +520,7 @@ kj::String LocalizeRadix(const char* input, const char* radix_pos) { // to divuldge the locale's radix character. No, localeconv() is NOT // thread-safe. char temp[16]; - int size = sprintf(temp, "%.1f", 1.5); + int size = snprintf(temp, sizeof(temp), "%.1f", 1.5); KJ_ASSERT(temp[0] == '1'); KJ_ASSERT(temp[size-1] == '5'); KJ_ASSERT(size <= 6); @@ -571,9 +613,26 @@ double parseDouble(const StringPtr& s) { return value; } +Maybe tryParseDouble(const StringPtr& s) { + if(s == nullptr) { return nullptr; } + char *endPtr; + errno = 0; + auto value = _::NoLocaleStrtod(s.begin(), &endPtr); + if (endPtr != s.end()) { return nullptr; } +#if _WIN32 || __CYGWIN__ || __BIONIC__ + if (isNaN(value)) { + return kj::nan(); + } +#endif + return value; +} + } // namespace _ (private) template <> double StringPtr::parseAs() const { return _::parseDouble(*this); } template <> float StringPtr::parseAs() const { return _::parseDouble(*this); } +template <> Maybe StringPtr::tryParseAs() const { return _::tryParseDouble(*this); } +template <> Maybe StringPtr::tryParseAs() const { return _::tryParseDouble(*this); } + } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/string.h b/libs/EXTERNAL/capnproto/c++/src/kj/string.h index 193442aad06..10b978875aa 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/string.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/string.h @@ -23,13 +23,16 @@ #include #include "array.h" +#include "kj/common.h" #include KJ_BEGIN_HEADER namespace kj { class StringPtr; + class LiteralStringConst; class String; + class ConstString; class StringTree; // string-tree.h } @@ -50,6 +53,8 @@ constexpr kj::StringPtr operator "" _kj(const char* str, size_t n); // string literal vs. one with _kj (assuming the compiler is able to optimize away strlen() on a // string literal). +constexpr kj::LiteralStringConst operator "" _kjc(const char* str, size_t n); + namespace kj { // Our STL string SFINAE trick does not work with GCC 4.7, but it works with Clang and GCC 4.8, so @@ -76,6 +81,7 @@ class StringPtr { inline StringPtr(const char* begin KJ_LIFETIMEBOUND, const char* end KJ_LIFETIMEBOUND): StringPtr(begin, end - begin) {} inline StringPtr(String&& value KJ_LIFETIMEBOUND) : StringPtr(value) {} inline StringPtr(const String& value KJ_LIFETIMEBOUND); + inline StringPtr(const ConstString& value KJ_LIFETIMEBOUND); StringPtr& operator=(String&& value) = delete; inline StringPtr& operator=(decltype(nullptr)) { content = ArrayPtr("", 1); @@ -92,15 +98,20 @@ class StringPtr { #endif #if KJ_COMPILER_SUPPORTS_STL_STRING_INTEROP - template ().c_str())> - inline StringPtr(const T& t KJ_LIFETIMEBOUND): StringPtr(t.c_str()) {} - // Allow implicit conversion from any class that has a c_str() method (namely, std::string). + template < + typename T, + typename = decltype(instance().c_str()), + typename = decltype(instance().size())> + inline StringPtr(const T& t KJ_LIFETIMEBOUND): StringPtr(t.c_str(), t.size()) {} + // Allow implicit conversion from any class that has a c_str() and a size() method (namely, std::string). // We use a template trick to detect std::string in order to avoid including the header for // those who don't want it. - - template ().c_str())> - inline operator T() const { return cStr(); } - // Allow implicit conversion to any class that has a c_str() method (namely, std::string). + template < + typename T, + typename = decltype(instance().c_str()), + typename = decltype(instance().size())> + inline operator T() const { return {cStr(), size()}; } + // Allow implicit conversion to any class that has a c_str() method and a size() method (namely, std::string). // We use a template trick to detect std::string in order to avoid including the header for // those who don't want it. #endif @@ -122,10 +133,14 @@ class StringPtr { inline constexpr const char* end() const { return content.end() - 1; } inline constexpr bool operator==(decltype(nullptr)) const { return content.size() <= 1; } +#if !__cpp_impl_three_way_comparison inline constexpr bool operator!=(decltype(nullptr)) const { return content.size() > 1; } +#endif inline bool operator==(const StringPtr& other) const; +#if !__cpp_impl_three_way_comparison inline bool operator!=(const StringPtr& other) const { return !(*this == other); } +#endif inline bool operator< (const StringPtr& other) const; inline bool operator> (const StringPtr& other) const { return other < *this; } inline bool operator<=(const StringPtr& other) const { return !(other < *this); } @@ -136,11 +151,11 @@ class StringPtr { // A string slice is only NUL-terminated if it is a suffix, so slice() has a one-parameter // version that assumes end = size(). - inline bool startsWith(const StringPtr& other) const; - inline bool endsWith(const StringPtr& other) const; + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } - inline Maybe findFirst(char c) const; - inline Maybe findLast(char c) const; + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } template T parseAs() const; @@ -149,13 +164,22 @@ class StringPtr { // Integer numbers prefixed by "0" are parsed in base 10 (unlike strtoi with base 0). // Overflowed integer numbers throw exception. // Overflowed floating numbers return inf. + template + Maybe tryParseAs() const; + // Same as parseAs, but rather than throwing an exception we return NULL. + + template + ConstString attach(Attachments&&... attachments) const KJ_WARN_UNUSED_RESULT; + ConstString attach() const KJ_WARN_UNUSED_RESULT; + // Like ArrayPtr::attach(), but instead promotes a StringPtr into a ConstString. Generally the + // attachment should be an object that somehow owns the String that the StringPtr is pointing at. private: inline explicit constexpr StringPtr(ArrayPtr content): content(content) {} + friend constexpr StringPtr (::operator "" _kj)(const char* str, size_t n); + friend class LiteralStringConst; ArrayPtr content; - - friend constexpr kj::StringPtr (::operator "" _kj)(const char* str, size_t n); friend class SourceLocation; }; @@ -178,6 +202,29 @@ template <> unsigned long long StringPtr::parseAs() const; template <> float StringPtr::parseAs() const; template <> double StringPtr::parseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; + +class LiteralStringConst: public StringPtr { +public: + inline operator ConstString() const; + +private: + inline explicit constexpr LiteralStringConst(ArrayPtr content): StringPtr(content) {} + friend constexpr LiteralStringConst (::operator "" _kjc)(const char* str, size_t n); +}; + // ======================================================================================= // String -- A NUL-terminated Array containing UTF-8 text. // @@ -231,14 +278,18 @@ class String { inline bool operator!=(decltype(nullptr)) const { return content.size() > 1; } inline bool operator==(const StringPtr& other) const { return StringPtr(*this) == other; } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const StringPtr& other) const { return StringPtr(*this) != other; } +#endif inline bool operator< (const StringPtr& other) const { return StringPtr(*this) < other; } inline bool operator> (const StringPtr& other) const { return StringPtr(*this) > other; } inline bool operator<=(const StringPtr& other) const { return StringPtr(*this) <= other; } inline bool operator>=(const StringPtr& other) const { return StringPtr(*this) >= other; } inline bool operator==(const String& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const String& other) const { return StringPtr(*this) != StringPtr(other); } +#endif inline bool operator< (const String& other) const { return StringPtr(*this) < StringPtr(other); } inline bool operator> (const String& other) const { return StringPtr(*this) > StringPtr(other); } inline bool operator<=(const String& other) const { return StringPtr(*this) <= StringPtr(other); } @@ -247,8 +298,17 @@ class String { // comparisons between two strings are ambiguous. (Clang turns this into a warning, // -Wambiguous-reversed-operator, due to the stupidity...) - inline bool startsWith(const StringPtr& other) const { return StringPtr(*this).startsWith(other);} - inline bool endsWith(const StringPtr& other) const { return StringPtr(*this).endsWith(other); } + inline bool operator==(const ConstString& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const ConstString& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const ConstString& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const ConstString& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const ConstString& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const ConstString& other) const { return StringPtr(*this) >= StringPtr(other); } + + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } inline StringPtr slice(size_t start) const KJ_LIFETIMEBOUND { return StringPtr(*this).slice(start); @@ -257,17 +317,120 @@ class String { return StringPtr(*this).slice(start, end); } - inline Maybe findFirst(char c) const { return StringPtr(*this).findFirst(c); } - inline Maybe findLast(char c) const { return StringPtr(*this).findLast(c); } + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } template T parseAs() const { return StringPtr(*this).parseAs(); } // Parse as number + template + Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + private: Array content; }; +// ======================================================================================= +// ConstString -- Same as String, but the backing buffer is const. +// +// This has the useful property that it can reference a string literal without allocating +// a copy. Any String can also convert (by move) to ConstString, transferring ownership of +// the buffer. + +class ConstString { +public: + ConstString() = default; + inline ConstString(decltype(nullptr)): content(nullptr) {} + inline ConstString(const char* value, size_t size, const ArrayDisposer& disposer); + // Does not copy. `size` does not include NUL terminator, but `value` must be NUL-terminated. + inline explicit ConstString(Array buffer); + // Does not copy. Requires `buffer` ends with `\0`. + inline explicit ConstString(String&& string): content(string.releaseArray()) {} + // Does not copy. Ownership is transfered. + + inline operator ArrayPtr() const KJ_LIFETIMEBOUND; + inline ArrayPtr asArray() const KJ_LIFETIMEBOUND; + inline ArrayPtr asBytes() const KJ_LIFETIMEBOUND { return asArray().asBytes(); } + // Result does not include NUL terminator. + + inline StringPtr asPtr() const KJ_LIFETIMEBOUND { + // Convenience operator to return a StringPtr. + return StringPtr{*this}; + } + + inline Array releaseArray() { return kj::mv(content); } + // Disowns the backing array (which includes the NUL terminator) and returns it. The ConstString value + // is clobbered (as if moved away). + + inline const char* cStr() const KJ_LIFETIMEBOUND; + + inline size_t size() const; + // Result does not include NUL terminator. + + inline char operator[](size_t index) const; + inline char& operator[](size_t index) KJ_LIFETIMEBOUND; + + inline const char* begin() const KJ_LIFETIMEBOUND; + inline const char* end() const KJ_LIFETIMEBOUND; + + inline bool operator==(decltype(nullptr)) const { return content.size() <= 1; } + inline bool operator!=(decltype(nullptr)) const { return content.size() > 1; } + + inline bool operator==(const StringPtr& other) const { return StringPtr(*this) == other; } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const StringPtr& other) const { return StringPtr(*this) != other; } +#endif + inline bool operator< (const StringPtr& other) const { return StringPtr(*this) < other; } + inline bool operator> (const StringPtr& other) const { return StringPtr(*this) > other; } + inline bool operator<=(const StringPtr& other) const { return StringPtr(*this) <= other; } + inline bool operator>=(const StringPtr& other) const { return StringPtr(*this) >= other; } + + inline bool operator==(const String& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const String& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const String& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const String& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const String& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const String& other) const { return StringPtr(*this) >= StringPtr(other); } + + inline bool operator==(const ConstString& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const ConstString& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const ConstString& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const ConstString& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const ConstString& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const ConstString& other) const { return StringPtr(*this) >= StringPtr(other); } + // Note that if we don't overload for `const ConstString&` specifically, then C++20 will decide that + // comparisons between two strings are ambiguous. (Clang turns this into a warning, + // -Wambiguous-reversed-operator, due to the stupidity...) + + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } + + inline StringPtr slice(size_t start) const KJ_LIFETIMEBOUND { + return StringPtr(*this).slice(start); + } + inline ArrayPtr slice(size_t start, size_t end) const KJ_LIFETIMEBOUND { + return StringPtr(*this).slice(start, end); + } + + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } + + template + T parseAs() const { return StringPtr(*this).parseAs(); } + // Parse as number + + template + Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + +private: + Array content; +}; + #if !__cpp_impl_three_way_comparison inline bool operator==(const char* a, const String& b) { return b == a; } inline bool operator!=(const char* a, const String& b) { return b != a; } @@ -397,6 +560,7 @@ struct Stringifier { return s.asArray(); } inline ArrayPtr operator*(const StringPtr& s) const { return s.asArray(); } + inline ArrayPtr operator*(const ConstString& s) const { return s.asArray(); } inline Range operator*(const Range& r) const { return r; } inline Repeat operator*(const Repeat& r) const { return r; } @@ -544,6 +708,7 @@ inline _::Delimited> operator*(const _::Stringifier&, const Ar // Inline implementation details. inline StringPtr::StringPtr(const String& value): content(value.cStr(), value.size() + 1) {} +inline StringPtr::StringPtr(const ConstString& value): content(value.cStr(), value.size() + 1) {} inline constexpr StringPtr::operator ArrayPtr() const { return ArrayPtr(content.begin(), content.size() - 1); @@ -572,31 +737,18 @@ inline ArrayPtr StringPtr::slice(size_t start, size_t end) const { return content.slice(start, end); } -inline bool StringPtr::startsWith(const StringPtr& other) const { - return other.content.size() <= content.size() && - memcmp(content.begin(), other.content.begin(), other.size()) == 0; -} -inline bool StringPtr::endsWith(const StringPtr& other) const { - return other.content.size() <= content.size() && - memcmp(end() - other.size(), other.content.begin(), other.size()) == 0; +inline LiteralStringConst::operator ConstString() const { + return ConstString(begin(), size(), NullArrayDisposer::instance); } -inline Maybe StringPtr::findFirst(char c) const { - const char* pos = reinterpret_cast(memchr(content.begin(), c, size())); - if (pos == nullptr) { - return nullptr; - } else { - return pos - content.begin(); - } +inline ConstString StringPtr::attach() const { + // This is meant as a roundabout way to make a ConstString from a StringPtr + return ConstString(begin(), size(), NullArrayDisposer::instance); } -inline Maybe StringPtr::findLast(char c) const { - for (size_t i = size(); i > 0; --i) { - if (content[i-1] == c) { - return i-1; - } - } - return nullptr; +template +inline ConstString StringPtr::attach(Attachments&&... attachments) const { + return ConstString { content.attach(kj::fwd(attachments)...) }; } inline String::operator ArrayPtr() { @@ -605,6 +757,9 @@ inline String::operator ArrayPtr() { inline String::operator ArrayPtr() const { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); } +inline ConstString::operator ArrayPtr() const { + return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); +} inline ArrayPtr String::asArray() { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); @@ -612,27 +767,42 @@ inline ArrayPtr String::asArray() { inline ArrayPtr String::asArray() const { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); } +inline ArrayPtr ConstString::asArray() const { + return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); +} inline const char* String::cStr() const { return content == nullptr ? "" : content.begin(); } +inline const char* ConstString::cStr() const { return content == nullptr ? "" : content.begin(); } inline size_t String::size() const { return content == nullptr ? 0 : content.size() - 1; } +inline size_t ConstString::size() const { return content == nullptr ? 0 : content.size() - 1; } inline char String::operator[](size_t index) const { return content[index]; } inline char& String::operator[](size_t index) { return content[index]; } +inline char ConstString::operator[](size_t index) const { return content[index]; } inline char* String::begin() { return content == nullptr ? nullptr : content.begin(); } inline char* String::end() { return content == nullptr ? nullptr : content.end() - 1; } inline const char* String::begin() const { return content == nullptr ? nullptr : content.begin(); } inline const char* String::end() const { return content == nullptr ? nullptr : content.end() - 1; } +inline const char* ConstString::begin() const { return content == nullptr ? nullptr : content.begin(); } +inline const char* ConstString::end() const { return content == nullptr ? nullptr : content.end() - 1; } inline String::String(char* value, size_t size, const ArrayDisposer& disposer) : content(value, size + 1, disposer) { KJ_IREQUIRE(value[size] == '\0', "String must be NUL-terminated."); } +inline ConstString::ConstString(const char* value, size_t size, const ArrayDisposer& disposer) + : content(value, size + 1, disposer) { + KJ_IREQUIRE(value[size] == '\0', "String must be NUL-terminated."); +} inline String::String(Array buffer): content(kj::mv(buffer)) { KJ_IREQUIRE(content.size() > 0 && content.back() == '\0', "String must be NUL-terminated."); } +inline ConstString::ConstString(Array buffer): content(kj::mv(buffer)) { + KJ_IREQUIRE(content.size() > 0 && content.back() == '\0', "String must be NUL-terminated."); +} inline String heapString(const char* value) { return heapString(value, strlen(value)); @@ -758,4 +928,8 @@ constexpr kj::StringPtr operator "" _kj(const char* str, size_t n) { return kj::StringPtr(kj::ArrayPtr(str, n + 1)); }; +constexpr kj::LiteralStringConst operator "" _kjc(const char* str, size_t n) { + return kj::LiteralStringConst(kj::ArrayPtr(str, n + 1)); +}; + KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/table-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/table-test.c++ index 202ab9b6798..708843a2bd7 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/table-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/table-test.c++ @@ -24,6 +24,8 @@ #include #include #include "hash.h" +#include "time.h" +#include namespace kj { namespace _ { @@ -45,7 +47,7 @@ KJ_TEST("_::tryReserveSize() works") { { Vector vec; tryReserveSize(vec, "foo"_kj); - KJ_EXPECT(vec.capacity() == 3); + KJ_EXPECT(vec.capacity() == 4); // Vectors always grow by powers of two. } { Vector vec; @@ -907,6 +909,90 @@ KJ_TEST("large tree table") { } } +KJ_TEST("TreeIndex fuzz test") { + // A test which randomly modifies a TreeIndex to try to discover buggy state changes. + + uint seed = (kj::systemPreciseCalendarClock().now() - kj::UNIX_EPOCH) / kj::NANOSECONDS; + KJ_CONTEXT(seed); // print the seed if the test fails + srand(seed); + + Table> table; + + auto randomInsert = [&]() { + table.upsert(rand(), [](auto&&, auto&&) {}); + }; + auto randomErase = [&]() { + if (table.size() > 0) { + auto& row = table.begin()[rand() % table.size()]; + table.erase(row); + } + }; + auto randomLookup = [&]() { + if (table.size() > 0) { + auto& row = table.begin()[rand() % table.size()]; + auto& found = KJ_ASSERT_NONNULL(table.find(row)); + KJ_ASSERT(&found == &row); + } + }; + + // First pass: focus on insertions, aim to do 2x as many insertions as deletions. + for (auto i KJ_UNUSED: kj::zeroTo(1000)) { + switch (rand() % 4) { + case 0: + case 1: + randomInsert(); + break; + case 2: + randomErase(); + break; + case 3: + randomLookup(); + break; + } + + table.verify(); + } + + // Second pass: focus on deletions, aim to do 2x as many deletions as insertions. + for (auto i KJ_UNUSED: kj::zeroTo(1000)) { + switch (rand() % 4) { + case 0: + randomInsert(); + break; + case 1: + case 2: + randomErase(); + break; + case 3: + randomLookup(); + break; + } + + table.verify(); + } +} + +KJ_TEST("TreeIndex clear() leaves tree in valid state") { + // A test which ensures that calling clear() does not break the internal state of a TreeIndex. + // It used to be the case that clearing a non-empty tree would leave it thinking that it had room + // for one more node than it really did, causing it to write and read beyond the end of its + // internal array of nodes. + Table> table; + + // Insert at least one value to allocate an initial set of tree nodes. + table.upsert(1, [](auto&&, auto&&) {}); + KJ_EXPECT(table.find(1) != nullptr); + table.clear(); + + // Insert enough values to force writes/reads beyond the end of the tree's internal node array. + for (uint i = 0; i < 29; ++i) { + table.upsert(i, [](auto&&, auto&&) {}); + } + for (uint i = 0; i < 29; ++i) { + KJ_EXPECT(table.find(i) != nullptr); + } +} + KJ_TEST("benchmark: kj::Table") { constexpr uint SOME_PRIME = BIG_PRIME; constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; @@ -1231,6 +1317,94 @@ KJ_TEST("insertion order index is movable") { KJ_EXPECT(iter == range.end()); } +// ======================================================================================= +// Test bug where insertion failure on a later index in the table would not be rolled back +// correctly if a previous index was TreeIndex. + +class StringLengthCompare { + // Considers two strings equal if they have the same length. +public: + inline size_t keyForRow(StringPtr entry) const { + return entry.size(); + } + + inline bool matches(StringPtr e, size_t key) const { + return e.size() == key; + } + + inline bool isBefore(StringPtr e, size_t key) const { + return e.size() < key; + } + + uint hashCode(size_t size) const { + return size; + } +}; + +KJ_TEST("HashIndex rollback on insertion failure") { + // Test that when an insertion produces a duplicate on a later index, changes to previous indexes + // are properly rolled back. + + Table, HashIndex> table; + table.insert("a"_kj); + table.insert("ab"_kj); + table.insert("abc"_kj); + + { + // We use upsert() so that we don't throw an exception from the duplicate, but this exercises + // the same logic as a duplicate insert() other than throwing. + kj::StringPtr& found = table.upsert("xyz"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "abc"); + KJ_EXPECT(param == "xyz"); + }); + KJ_EXPECT(found == "abc"); + + table.erase(found); + } + + table.insert("xyz"_kj); + + { + kj::StringPtr& found = table.upsert("tuv"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "xyz"); + KJ_EXPECT(param == "tuv"); + }); + KJ_EXPECT(found == "xyz"); + } +} + +KJ_TEST("TreeIndex rollback on insertion failure") { + // Test that when an insertion produces a duplicate on a later index, changes to previous indexes + // are properly rolled back. + + Table, TreeIndex> table; + table.insert("a"_kj); + table.insert("ab"_kj); + table.insert("abc"_kj); + + { + // We use upsert() so that we don't throw an exception from the duplicate, but this exercises + // the same logic as a duplicate insert() other than throwing. + kj::StringPtr& found = table.upsert("xyz"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "abc"); + KJ_EXPECT(param == "xyz"); + }); + KJ_EXPECT(found == "abc"); + + table.erase(found); + } + + table.insert("xyz"_kj); + + { + kj::StringPtr& found = table.upsert("tuv"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "xyz"); + KJ_EXPECT(param == "tuv"); + }); + KJ_EXPECT(found == "xyz"); + } +} + } // namespace } // namespace _ } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/table.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/table.c++ index 62cfa6e4e34..4b0e028524c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/table.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/table.c++ @@ -23,6 +23,11 @@ #include "debug.h" #include +#if KJ_DEBUG_TABLE_IMPL +#undef KJ_DASSERT +#define KJ_DASSERT KJ_ASSERT +#endif + namespace kj { namespace _ { @@ -257,28 +262,39 @@ size_t BTreeImpl::verifyNode(size_t size, FunctionParam& f, auto n = parent.keyCount(); size_t total = 0; for (auto i: kj::zeroTo(n)) { - KJ_ASSERT(*parent.keys[i] < size); + KJ_ASSERT(*parent.keys[i] < size, n, i); total += verifyNode(size, f, parent.children[i], height - 1, parent.keys[i]); - KJ_ASSERT(i + 1 == n || f(*parent.keys[i], *parent.keys[i + 1])); + if (i > 0) { + KJ_ASSERT(f(*parent.keys[i - 1], *parent.keys[i]), + n, i, parent.keys[i - 1], parent.keys[i]); + } } total += verifyNode(size, f, parent.children[n], height - 1, maxRow); - KJ_ASSERT(maxRow == nullptr || f(*parent.keys[n-1], *maxRow)); + if (maxRow != nullptr) { + KJ_ASSERT(f(*parent.keys[n-1], *maxRow), n, parent.keys[n-1], maxRow); + } return total; } else { auto& leaf = tree[pos].leaf; auto n = leaf.size(); for (auto i: kj::zeroTo(n)) { - KJ_ASSERT(*leaf.rows[i] < size); - if (i + 1 < n) { - KJ_ASSERT(f(*leaf.rows[i], *leaf.rows[i + 1])); - } else { - KJ_ASSERT(maxRow == nullptr || leaf.rows[n-1] == maxRow); + KJ_ASSERT(*leaf.rows[i] < size, n, i); + if (i > 0) { + KJ_ASSERT(f(*leaf.rows[i - 1], *leaf.rows[i]), + n, i, leaf.rows[i - 1], leaf.rows[i]); } } + if (maxRow != nullptr) { + KJ_ASSERT(leaf.rows[n-1] == maxRow, n); + } return n; } } +kj::String BTreeImpl::MaybeUint::toString() const { + return i == 0 ? kj::str("(null)") : kj::str(i - 1); +} + void BTreeImpl::logInconsistency() const { KJ_LOG(ERROR, "BTreeIndex detected tree state inconsistency. This can happen if you create a kj::Table " @@ -319,7 +335,7 @@ void BTreeImpl::clear() { azero(tree, treeCapacity); height = 0; freelistHead = 1; - freelistSize = treeCapacity; + freelistSize = treeCapacity - 1; // subtract one to account for the root node beginLeaf = 0; endLeaf = 0; } @@ -651,7 +667,7 @@ void BTreeImpl::renumber(uint oldRow, uint newRow, const SearchKey& searchKey) { auto& node = tree[pos].parent; uint indexInParent = searchKey.search(node); pos = node.children[indexInParent]; - if (node.keys[indexInParent] == oldRow) { + if (indexInParent < kj::size(node.keys) && node.keys[indexInParent] == oldRow) { node.keys[indexInParent] = newRow; } KJ_DASSERT(pos != 0); @@ -715,7 +731,7 @@ void BTreeImpl::merge(Leaf& dst, uint dstPos, uint pivot, Leaf& src) { KJ_DASSERT(dst.isHalfFull()); constexpr size_t mid = Leaf::NROWS/2; - dst.rows[mid] = pivot; + KJ_DASSERT(dst.rows[mid-1] == pivot); acopy(dst.rows + mid, src.rows, mid); dst.next = src.next; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/table.h b/libs/EXTERNAL/capnproto/c++/src/kj/table.h index f3e19ddcb0f..d5d1b413718 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/table.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/table.h @@ -35,10 +35,21 @@ #endif #endif +#if KJ_DEBUG_TABLE_IMPL +#include "debug.h" +#define KJ_TABLE_IREQUIRE KJ_REQUIRE +#define KJ_TABLE_IASSERT KJ_ASSERT +#else +#define KJ_TABLE_IREQUIRE KJ_IREQUIRE +#define KJ_TABLE_IASSERT KJ_IASSERT +#endif + KJ_BEGIN_HEADER namespace kj { +class String; + namespace _ { // private template @@ -325,7 +336,7 @@ class HashIndex; // // methods to match this row. // // bool matches(const Row&, SearchParams&&...) const; -// // Returns true if the row on the left matches thes search params on the right. +// // Returns true if the row on the left matches the search params on the right. // // uint hashCode(SearchParams&&...) const; // // Computes the hash code of the given search params. Matching rows (as determined by @@ -761,7 +772,7 @@ void Table::verify() { template void Table::erase(Row& row) { - KJ_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); + KJ_TABLE_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); eraseImpl(&row - rows.begin()); } template @@ -777,7 +788,7 @@ void Table::eraseImpl(size_t pos) { template Row Table::release(Row& row) { - KJ_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); + KJ_TABLE_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); size_t pos = &row - rows.begin(); Impl<>::erase(*this, pos, row); Row result = kj::mv(row); @@ -863,7 +874,7 @@ struct HashBucket { inline const Row& getRow(ArrayPtr table) const { return table[getPos()]; } inline bool isPos(uint pos) const { return pos + 2 == value; } inline uint getPos() const { - KJ_IASSERT(value >= 2); + KJ_TABLE_IASSERT(value >= 2); return value - 2; } inline void setEmpty() { value = 0; } @@ -906,7 +917,7 @@ class HashIndex { void clear() { erasedCount = 0; - memset(buckets.begin(), 0, buckets.asBytes().size()); + if (buckets.size() > 0) memset(buckets.begin(), 0, buckets.asBytes().size()); } template @@ -1166,12 +1177,14 @@ class BTreeImpl::MaybeUint { inline MaybeUint& operator=(decltype(nullptr)) { i = 0; return *this; } inline MaybeUint& operator=(uint j) { i = j + 1; return *this; } - inline uint operator*() const { KJ_IREQUIRE(i != 0); return i - 1; } + inline uint operator*() const { KJ_TABLE_IREQUIRE(i != 0); return i - 1; } template inline bool check(Func& func) const { return i != 0 && func(i - 1); } // Equivalent to *this != nullptr && func(**this) + kj::String toString() const; + private: uint i; }; @@ -1190,14 +1203,14 @@ struct BTreeImpl::Leaf { inline bool isHalfFull() const; inline void insert(uint i, uint newRow) { - KJ_IREQUIRE(rows[Leaf::NROWS - 1] == nullptr); // check not full + KJ_TABLE_IREQUIRE(rows[Leaf::NROWS - 1] == nullptr); // check not full amove(rows + i + 1, rows + i, Leaf::NROWS - (i + 1)); rows[i] = newRow; } inline void erase(uint i) { - KJ_IREQUIRE(rows[0] != nullptr); // check not empty + KJ_TABLE_IREQUIRE(rows[0] != nullptr); // check not empty amove(rows + i, rows + i + 1, Leaf::NROWS - (i + 1)); rows[Leaf::NROWS - 1] = nullptr; @@ -1241,6 +1254,12 @@ struct BTreeImpl::Parent { static constexpr size_t NKEYS = 7; MaybeUint keys[NKEYS]; // Pointers to table rows, offset by 1 so that 0 is an empty value. + // + // Each keys[i] specifies the table row which is the "last" row found under children[i]. + // + // Note that `keys` has size 7 but `children` has size 8. `children[8]`'s "last row" is not + // recorded here, because the Parent's Parent records it instead. (Or maybe the Parent's Parent's + // Parent, if this Parent is `children[8]` of its own Parent. And so on.) static constexpr size_t NCHILDREN = NKEYS + 1; uint children[NCHILDREN]; @@ -1324,7 +1343,7 @@ bool BTreeImpl::Leaf::isMostlyFull() const { return rows[Leaf::NROWS / 2] != nullptr; } bool BTreeImpl::Leaf::isHalfFull() const { - KJ_IASSERT(rows[Leaf::NROWS / 2 - 1] != nullptr); + KJ_TABLE_IASSERT(rows[Leaf::NROWS / 2 - 1] != nullptr); return rows[Leaf::NROWS / 2] == nullptr; } @@ -1335,7 +1354,7 @@ bool BTreeImpl::Parent::isMostlyFull() const { return keys[Parent::NKEYS / 2] != nullptr; } bool BTreeImpl::Parent::isHalfFull() const { - KJ_IASSERT(keys[Parent::NKEYS / 2 - 1] != nullptr); + KJ_TABLE_IASSERT(keys[Parent::NKEYS / 2 - 1] != nullptr); return keys[Parent::NKEYS / 2] == nullptr; } @@ -1345,13 +1364,13 @@ class BTreeImpl::Iterator { : tree(tree), leaf(leaf), row(row) {} size_t operator*() const { - KJ_IREQUIRE(row < Leaf::NROWS && leaf->rows[row] != nullptr, + KJ_TABLE_IREQUIRE(row < Leaf::NROWS && leaf->rows[row] != nullptr, "tried to dereference end() iterator"); return *leaf->rows[row]; } inline Iterator& operator++() { - KJ_IREQUIRE(leaf->rows[row] != nullptr, "B-tree iterator overflow"); + KJ_TABLE_IREQUIRE(leaf->rows[row] != nullptr, "B-tree iterator overflow"); ++row; if (row >= Leaf::NROWS || leaf->rows[row] == nullptr) { if (leaf->next == 0) { @@ -1371,7 +1390,7 @@ class BTreeImpl::Iterator { inline Iterator& operator--() { if (row == 0) { - KJ_IREQUIRE(leaf->prev != 0, "B-tree iterator underflow"); + KJ_TABLE_IREQUIRE(leaf->prev != 0, "B-tree iterator underflow"); leaf = &tree[leaf->prev].leaf; row = leaf->size() - 1; } else { @@ -1397,17 +1416,17 @@ class BTreeImpl::Iterator { } void insert(BTreeImpl& impl, uint newRow) { - KJ_IASSERT(impl.tree == tree); + KJ_TABLE_IASSERT(impl.tree == tree); const_cast(leaf)->insert(row, newRow); } void erase(BTreeImpl& impl) { - KJ_IASSERT(impl.tree == tree); + KJ_TABLE_IASSERT(impl.tree == tree); const_cast(leaf)->erase(row); } void replace(BTreeImpl& impl, uint newRow) { - KJ_IASSERT(impl.tree == tree); + KJ_TABLE_IASSERT(impl.tree == tree); const_cast(leaf)->rows[row] = newRow; } @@ -1465,7 +1484,7 @@ class TreeIndex { template void erase(kj::ArrayPtr table, size_t pos, Params&&... params) { - impl.erase(pos, searchKey(table, params...)); + impl.erase(pos, searchKeyForErase(table, pos, params...)); } template @@ -1518,6 +1537,16 @@ class TreeIndex { auto predicate = [&](uint i) { return cb.isBefore(table[i], params...); }; return SearchKeyImpl(kj::mv(predicate)); } + + template + inline auto searchKeyForErase(kj::ArrayPtr& table, uint pos, Params&... params) const { + // When erasing, the table entry for the erased row may already be invalid, so we must avoid + // accessing it. + auto predicate = [&,pos](uint i) { + return i != pos && cb.isBefore(table[i], params...); + }; + return SearchKeyImpl(kj::mv(predicate)); + } }; // ----------------------------------------------------------------------------- @@ -1542,7 +1571,7 @@ class InsertionOrderIndex { : links(links), pos(pos) {} inline size_t operator*() const { - KJ_IREQUIRE(pos != 0, "can't derefrence end() iterator"); + KJ_TABLE_IREQUIRE(pos != 0, "can't dereference end() iterator"); return pos - 1; }; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/test-helpers.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/test-helpers.c++ index c4fcbc19fd5..6ae8cd32598 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/test-helpers.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/test-helpers.c++ @@ -34,8 +34,6 @@ #include #endif -#include - namespace kj { namespace _ { // private @@ -49,9 +47,6 @@ bool hasSubstring(StringPtr haystack, StringPtr needle) { #if !defined(_WIN32) return memmem(haystack.begin(), haystack.size(), needle.begin(), needle.size()) != nullptr; -#elif defined(__cpp_lib_boyer_moore_searcher) - std::boyer_moore_horspool_searcher searcher{needle.begin(), needle.size()}; - return std::search(haystack.begin(), haystack.end(), searcher) != haystack.end(); #else // TODO(perf): This is not the best algorithm for substring matching. strstr can't be used // because this is supposed to be safe to call on strings with embedded nulls. @@ -152,7 +147,75 @@ bool expectFatalThrow(kj::Maybe type, kj::Maybe mess KJ_FAIL_EXPECT("subprocess crashed without throwing exception", WTERMSIG(status)); return false; } else { - KJ_FAIL_EXPECT("subprocess neiter excited nor crashed?", status); + KJ_FAIL_EXPECT("subprocess neither excited nor crashed?", status); + return false; + } +#endif +} + +bool expectExit(Maybe statusCode, FunctionParam code) noexcept { +#if _WIN32 + // We don't support death tests on Windows due to lack of efficient fork. + return true; +#else + pid_t child; + KJ_SYSCALL(child = fork()); + if (child == 0) { + code(); + _exit(0); + } + + int status; + KJ_SYSCALL(waitpid(child, &status, 0)); + + if (WIFEXITED(status)) { + KJ_IF_MAYBE(s, statusCode) { + KJ_EXPECT(WEXITSTATUS(status) == *s); + return WEXITSTATUS(status) == *s; + } else { + KJ_EXPECT(WEXITSTATUS(status) != 0); + return WEXITSTATUS(status) != 0; + } + } else { + if (WIFSIGNALED(status)) { + KJ_FAIL_EXPECT("subprocess didn't exit but triggered a signal", strsignal(WTERMSIG(status))); + } else { + KJ_FAIL_EXPECT("subprocess didn't exit and didn't trigger a signal", status); + } + return false; + } +#endif +} + + +bool expectSignal(Maybe signal, FunctionParam code) noexcept { +#if _WIN32 + // We don't support death tests on Windows due to lack of efficient fork. + return true; +#else + pid_t child; + KJ_SYSCALL(child = fork()); + if (child == 0) { + resetCrashHandlers(); + code(); + _exit(0); + } + + int status; + KJ_SYSCALL(waitpid(child, &status, 0)); + + if (WIFSIGNALED(status)) { + KJ_IF_MAYBE(s, signal) { + KJ_EXPECT(WTERMSIG(status) == *s); + return WTERMSIG(status) == *s; + } + return true; + } else { + if (WIFEXITED(status)) { + KJ_FAIL_EXPECT("subprocess didn't trigger a signal but exited", WEXITSTATUS(status)); + } else { + KJ_FAIL_EXPECT("subprocess didn't exit and didn't trigger a signal", status); + } return false; } #endif diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/test-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/test-test.c++ index 7d020270940..b69eaf6ab76 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/test-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/test-test.c++ @@ -21,6 +21,13 @@ #include "common.h" #include "test.h" +#include +#include +#include + +#ifndef _WIN32 +#include +#endif namespace kj { namespace _ { @@ -78,6 +85,26 @@ KJ_TEST("GlobFilter") { } } +KJ_TEST("expect exit from exit") { + KJ_EXPECT_EXIT(42, _exit(42)); + KJ_EXPECT_EXIT(nullptr, _exit(42)); +} + +#if !KJ_NO_EXCEPTIONS +KJ_TEST("expect exit from thrown exception") { + KJ_EXPECT_EXIT(1, throw std::logic_error("test error")); +} +#endif + +KJ_TEST("expect signal from abort") { + KJ_EXPECT_SIGNAL(SIGABRT, abort()); +} + +KJ_TEST("expect signal from sigint") { + KJ_EXPECT_SIGNAL(SIGINT, raise(SIGINT)); + KJ_EXPECT_SIGNAL(nullptr, raise(SIGINT)); +} + } // namespace } // namespace _ } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/test.c++ index e988320f894..e310f20f05c 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/test.c++ @@ -42,6 +42,8 @@ namespace { TestCase* testCasesHead = nullptr; TestCase** testCasesTail = &testCasesHead; +size_t benchmarkIterCount = 1; + } // namespace TestCase::TestCase(const char* file, uint line, const char* description) @@ -60,6 +62,10 @@ TestCase::~TestCase() { } } +size_t TestCase::iterCount() { + return benchmarkIterCount; +} + // ======================================================================================= namespace _ { // private @@ -172,7 +178,7 @@ public: if (severity == LogSeverity::ERROR || severity == LogSeverity::FATAL) { sawError = true; - context.error(kj::str(text, "\nstack: ", strArray(trace, " "), stringifyStackTrace(trace))); + context.error(kj::str(text, "\nstack: ", stringifyStackTraceAddresses(trace), stringifyStackTrace(trace))); } else { context.warning(text); } @@ -205,6 +211,9 @@ public: .addOption({'l', "list"}, KJ_BIND_METHOD(*this, setList), "List all test cases that would run, but don't run them. If --filter is specified " "then only the match tests will be listed.") + .addOptionWithArg({'b', "benchmark"}, KJ_BIND_METHOD(*this, setBenchmarkIters), "", + "Specifies that any benchmarks in the tests should run for iterations. " + "If not specified, then count is 1, which simply tests that the benchmarks function.") .callAfterParsing(KJ_BIND_METHOD(*this, run)) .build(); } @@ -263,6 +272,15 @@ public: return true; } + MainBuilder::Validity setBenchmarkIters(StringPtr param) { + KJ_IF_MAYBE(i, param.tryParseAs()) { + benchmarkIterCount = *i; + return true; + } else { + return "expected an integer"; + } + } + MainBuilder::Validity run() { if (testCasesHead == nullptr) { return "no tests were declared"; diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/test.h b/libs/EXTERNAL/capnproto/c++/src/kj/test.h index fbc34492ecf..5acbb00d40b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/test.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/test.h @@ -39,6 +39,21 @@ class TestCase { virtual void run() = 0; +protected: + template + void doBenchmark(Func&& func) { + // Perform a benchmark with configurable iterations. func() will be called N times, where N + // is set by the --benchmark CLI flag. This defaults to 1, so that when --benchmark is not + // specified, we only test that the benchmark works. + // + // In the future, this could adaptively choose iteration count by running a few iterations to + // find out how fast the benchmark is, then scaling. + + for (size_t i = iterCount(); i-- > 0;) { + func(); + } + } + private: const char* file; uint line; @@ -47,6 +62,8 @@ class TestCase { TestCase** prev; bool matchedFilter; + static size_t iterCount(); + friend class TestRunner; }; @@ -60,7 +77,7 @@ class TestCase { } KJ_UNIQUE_NAME(testCase); \ void KJ_UNIQUE_NAME(TestCase)::run() -#if _MSC_VER && !defined(__clang__) +#if KJ_MSVC_TRADITIONAL_CPP #define KJ_INDIRECT_EXPAND(m, vargs) m vargs #define KJ_FAIL_EXPECT(...) \ KJ_INDIRECT_EXPAND(KJ_LOG, (ERROR , __VA_ARGS__)); @@ -75,32 +92,54 @@ class TestCase { else KJ_FAIL_EXPECT("failed: expected " #cond, _kjCondition, ##__VA_ARGS__) #endif -#define KJ_EXPECT_THROW_RECOVERABLE(type, code) \ +#if _MSC_VER && !defined(__clang__) +#define KJ_EXPECT_THROW_RECOVERABLE(type, code, ...) \ + do { \ + KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ + KJ_INDIRECT_EXPAND(KJ_EXPECT, (e->getType() == ::kj::Exception::Type::type, \ + "code threw wrong exception type: " #code, *e, __VA_ARGS__)); \ + } else { \ + KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("code did not throw: " #code, __VA_ARGS__)); \ + } \ + } while (false) + +#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code, ...) \ + do { \ + KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ + KJ_INDIRECT_EXPAND(KJ_EXPECT, (::kj::_::hasSubstring(e->getDescription(), message), \ + "exception description didn't contain expected substring", *e, __VA_ARGS__)); \ + } else { \ + KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("code did not throw: " #code, __VA_ARGS__)); \ + } \ + } while (false) +#else +#define KJ_EXPECT_THROW_RECOVERABLE(type, code, ...) \ do { \ KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ KJ_EXPECT(e->getType() == ::kj::Exception::Type::type, \ - "code threw wrong exception type: " #code, *e); \ + "code threw wrong exception type: " #code, *e, ##__VA_ARGS__); \ } else { \ - KJ_FAIL_EXPECT("code did not throw: " #code); \ + KJ_FAIL_EXPECT("code did not throw: " #code, ##__VA_ARGS__); \ } \ } while (false) -#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code) \ +#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code, ...) \ do { \ KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ KJ_EXPECT(::kj::_::hasSubstring(e->getDescription(), message), \ - "exception description didn't contain expected substring", *e); \ + "exception description didn't contain expected substring", *e, ##__VA_ARGS__); \ } else { \ - KJ_FAIL_EXPECT("code did not throw: " #code); \ + KJ_FAIL_EXPECT("code did not throw: " #code, ##__VA_ARGS__); \ } \ } while (false) +#endif #if KJ_NO_EXCEPTIONS -#define KJ_EXPECT_THROW(type, code) \ +#define KJ_EXPECT_THROW(type, code, ...) \ do { \ KJ_EXPECT(::kj::_::expectFatalThrow(::kj::Exception::Type::type, nullptr, [&]() { code; })); \ } while (false) -#define KJ_EXPECT_THROW_MESSAGE(message, code) \ +#define KJ_EXPECT_THROW_MESSAGE(message, code, ...) \ do { \ KJ_EXPECT(::kj::_::expectFatalThrow(nullptr, kj::StringPtr(message), [&]() { code; })); \ } while (false) @@ -109,6 +148,19 @@ class TestCase { #define KJ_EXPECT_THROW_MESSAGE KJ_EXPECT_THROW_RECOVERABLE_MESSAGE #endif +#define KJ_EXPECT_EXIT(statusCode, code) \ + do { \ + KJ_EXPECT(::kj::_::expectExit(statusCode, [&]() { code; })); \ + } while (false) +// Forks the code and expects it to exit with a given code. + +#define KJ_EXPECT_SIGNAL(signal, code) \ + do { \ + KJ_EXPECT(::kj::_::expectSignal(signal, [&]() { code; })); \ + } while (false) +// Forks the code and expects it to trigger a signal. +// In the child resets all signal handlers as printStackTraceOnCrash sets. + #define KJ_EXPECT_LOG(level, substring) \ ::kj::_::LogExpectation KJ_UNIQUE_NAME(_kjLogExpectation)(::kj::LogSeverity::level, substring) // Expects that a log message with the given level and substring text will be printed within @@ -128,6 +180,17 @@ bool expectFatalThrow(Maybe type, Maybe message, // fork() is not available, this always returns true. #endif +bool expectExit(Maybe statusCode, FunctionParam code) noexcept; +// Expects that the given code will exit with a given statusCode. +// The test will fork() and run in a subprocess. On Windows, where fork() is not available, +// this always returns true. + +bool expectSignal(Maybe signal, FunctionParam code) noexcept; +// Expects that the given code will trigger a signal. +// The test will fork() and run in a subprocess. On Windows, where fork() is not available, +// this always returns true. +// Resets signal handlers to default prior to running the code in the child process. + class LogExpectation: public ExceptionCallback { public: LogExpectation(LogSeverity severity, StringPtr substring); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/thread.h b/libs/EXTERNAL/capnproto/c++/src/kj/thread.h index 46fd39bb9c2..2261ab12c91 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/thread.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/thread.h @@ -36,7 +36,7 @@ class Thread { public: explicit Thread(Function func); - KJ_DISALLOW_COPY(Thread); + KJ_DISALLOW_COPY_AND_MOVE(Thread); ~Thread() noexcept(false); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/time-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/time-test.c++ index 0dd5d64b853..8bbdf9366b6 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/time-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/time-test.c++ @@ -46,6 +46,10 @@ KJ_TEST("stringify times") { KJ_EXPECT(kj::str(50 * kj::MICROSECONDS) == "50μs"); KJ_EXPECT(kj::str(5 * kj::MICROSECONDS + 300 * kj::NANOSECONDS) == "5.3μs"); KJ_EXPECT(kj::str(50 * kj::NANOSECONDS) == "50ns"); + KJ_EXPECT(kj::str(-256 * kj::MILLISECONDS) == "-256ms"); + KJ_EXPECT(kj::str(-50 * kj::NANOSECONDS) == "-50ns"); + KJ_EXPECT(kj::str((int64_t)kj::maxValue * kj::NANOSECONDS) == "9223372036.854775807s"); + KJ_EXPECT(kj::str((int64_t)kj::minValue * kj::NANOSECONDS) == "-9223372036.854775808s"); } #if _WIN32 diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/time.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/time.c++ index d0846588415..98ebbb1f7b6 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/time.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/time.c++ @@ -261,14 +261,20 @@ const MonotonicClock& systemPreciseMonotonicClock() { #endif -kj::String KJ_STRINGIFY(TimePoint t) { +CappedArray KJ_STRINGIFY(TimePoint t) { return kj::toCharSequence(t - kj::origin()); } -kj::String KJ_STRINGIFY(Date d) { +CappedArray KJ_STRINGIFY(Date d) { return kj::toCharSequence(d - UNIX_EPOCH); } -kj::String KJ_STRINGIFY(Duration d) { - auto digits = kj::toCharSequence(d / kj::NANOSECONDS); +CappedArray KJ_STRINGIFY(Duration d) { + bool negative = d < 0 * kj::SECONDS; + uint64_t ns = d / kj::NANOSECONDS; + if (negative) { + ns = -ns; + } + + auto digits = kj::toCharSequence(ns); ArrayPtr arr = digits; size_t point; @@ -292,15 +298,24 @@ kj::String KJ_STRINGIFY(Duration d) { unit = kj::NANOSECONDS; } + CappedArray result; + char* begin = result.begin(); + char* end; + if (negative) { + *begin++ = '-'; + } if (d % unit == 0 * kj::NANOSECONDS) { - return kj::str(arr.slice(0, point), suffix); + end = _::fillLimited(begin, result.end(), arr.slice(0, point), suffix); } else { while (arr.back() == '0') { arr = arr.slice(0, arr.size() - 1); } - KJ_ASSERT(arr.size() > point); - return kj::str(arr.slice(0, point), ".", arr.slice(point, arr.size()), suffix); + KJ_DASSERT(arr.size() > point); + end = _::fillLimited(begin, result.end(), arr.slice(0, point), "."_kj, + arr.slice(point, arr.size()), suffix); } + result.setSize(end - result.begin()); + return result; } } // namespace kj diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/time.h b/libs/EXTERNAL/capnproto/c++/src/kj/time.h index 0c2e47af718..aaf1031d7e4 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/time.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/time.h @@ -35,6 +35,10 @@ class NanosecondLabel; class TimeLabel; class DateLabel; +static constexpr size_t TIME_STR_LEN = sizeof(int64_t) * 3 + 8; +// Maximum length of a stringified time. 3 digits per byte of integer, plus 8 digits to cover +// negative sign, decimal point, unit, NUL terminator, and anything else that might sneak in. + } // namespace _ (private) using Duration = Quantity; @@ -56,9 +60,9 @@ using TimePoint = Absolute; using Date = Absolute; // A point in real-world time, measured relative to the Unix epoch (Jan 1, 1970 00:00:00 UTC). -kj::String KJ_STRINGIFY(TimePoint); -kj::String KJ_STRINGIFY(Date); -kj::String KJ_STRINGIFY(Duration); +CappedArray KJ_STRINGIFY(TimePoint); +CappedArray KJ_STRINGIFY(Date); +CappedArray KJ_STRINGIFY(Duration); constexpr Date UNIX_EPOCH = origin(); // The `Date` representing Jan 1, 1970 00:00:00 UTC. @@ -110,7 +114,6 @@ const MonotonicClock& systemPreciseMonotonicClock(); // The "coarse" version has precision around 1-10ms, while the "precise" version has precision // better than 1us. The "precise" version may be slightly slower, though on modern hardware and // a reasonable operating system the difference is usually negligible. - } // namespace kj KJ_END_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/timer.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/timer.c++ index 993b683ce24..e5cd26484ed 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/timer.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/timer.c++ @@ -110,9 +110,17 @@ Maybe TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, ui } void TimerImpl::advanceTo(TimePoint newTime) { + // On Macs running an Intel processor, it has been observed that clock_gettime + // may return non monotonic time, even when CLOCK_MONOTONIC is used. + // This workaround is to avoid the assert triggering on these machines. + // See also https://github.com/capnproto/capnproto/issues/1693 +#if __APPLE__ && defined(__x86_64__) + time = std::max(time, newTime); +#else KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; } - time = newTime; +#endif + for (;;) { auto front = impl->timers.begin(); if (front == impl->timers.end() || (*front)->time > time) { diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/timer.h b/libs/EXTERNAL/capnproto/c++/src/kj/timer.h index 862f97b9c79..eb9443c23bc 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/timer.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/timer.h @@ -22,7 +22,7 @@ #pragma once -#include "time.h" +#include #include "async.h" KJ_BEGIN_HEADER diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/tuple.h b/libs/EXTERNAL/capnproto/c++/src/kj/tuple.h index 2a526c0c329..1351912b099 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/tuple.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/tuple.h @@ -98,7 +98,7 @@ struct TupleElement { template struct TupleElement { - // A tuple containing references can be constucted using refTuple(). + // A tuple containing references can be constructed using refTuple(). T& value; constexpr inline TupleElement(T& value): value(value) {} diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/units-test.c++ b/libs/EXTERNAL/capnproto/c++/src/kj/units-test.c++ index 892c1d39862..31a29737937 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/units-test.c++ +++ b/libs/EXTERNAL/capnproto/c++/src/kj/units-test.c++ @@ -341,7 +341,7 @@ TEST(UnitMeasure, BoundedMinMax) { assertTypeAndValue(boundedValue<4,t1>(2), kj::min(bounded<4>(), boundedValue<5,t1>(2))); assertTypeAndValue(boundedValue<4,t1>(2), kj::min(boundedValue<5,t1>(2), bounded<4>())); - // These two are degenerate cases. Currently they fail to compile but meybe they shouldn't? + // These two are degenerate cases. Currently they fail to compile but maybe they shouldn't? // assertTypeAndValue(bounded<5>(), kj::max(boundedValue<4,t2>(3), bounded<5>())); // assertTypeAndValue(bounded<5>(), kj::max(bounded<5>(), boundedValue<4,t2>(3))); diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/units.h b/libs/EXTERNAL/capnproto/c++/src/kj/units.h index e843b12433d..530abafbee5 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/units.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/units.h @@ -74,19 +74,6 @@ class Bounded; template class BoundedConst; -template constexpr bool isIntegral() { return false; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } - template struct IsIntegralOrBounded_ { static constexpr bool value = isIntegral(); }; template diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/vector.h b/libs/EXTERNAL/capnproto/c++/src/kj/vector.h index 60a370a0f26..d072448fc6b 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/vector.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/vector.h @@ -123,7 +123,7 @@ class Vector { inline void reserve(size_t size) { if (size > builder.capacity()) { - setCapacity(size); + grow(size); } } diff --git a/libs/EXTERNAL/capnproto/c++/src/kj/windows-sanity.h b/libs/EXTERNAL/capnproto/c++/src/kj/windows-sanity.h index 64475dc41dd..b2c93678d62 100644 --- a/libs/EXTERNAL/capnproto/c++/src/kj/windows-sanity.h +++ b/libs/EXTERNAL/capnproto/c++/src/kj/windows-sanity.h @@ -48,10 +48,13 @@ // now, we use `#pragma once` to tell the compiler never to include this file again. #pragma once -namespace win32 { - const auto ERROR_ = ERROR; +namespace kj_win32_workarounds { + // Namespace containing constant definitions intended to replace constants that are defined as + // macros in the Windows headers. Do not refer to this namespace directly, we'll import it into + // the global scope below. #ifdef ERROR // This could be absent if e.g. NOGDI was used. + const auto ERROR_ = ERROR; #undef ERROR const auto ERROR = ERROR_; #endif @@ -61,7 +64,8 @@ namespace win32 { typedef VOID_ VOID; } -using win32::ERROR; -using win32::VOID; +// Pull our constant definitions into the global namespace -- but only if they don't already exist +// in the global namespace. +using namespace kj_win32_workarounds; #endif diff --git a/libs/EXTERNAL/capnproto/doc/_includes/buttons.html b/libs/EXTERNAL/capnproto/doc/_includes/buttons.html index 7a7ec7567f7..66ec6248ca2 100644 --- a/libs/EXTERNAL/capnproto/doc/_includes/buttons.html +++ b/libs/EXTERNAL/capnproto/doc/_includes/buttons.html @@ -1,6 +1,6 @@
Develop +href="https://github.com/capnproto/capnproto">Develop Discuss diff --git a/libs/EXTERNAL/capnproto/doc/_includes/header.html b/libs/EXTERNAL/capnproto/doc/_includes/header.html index f249d017301..e835a3e3237 100644 --- a/libs/EXTERNAL/capnproto/doc/_includes/header.html +++ b/libs/EXTERNAL/capnproto/doc/_includes/header.html @@ -26,7 +26,7 @@
Discuss on Groups - View on GitHub + View on GitHub {% if page.title != "Introduction" %}{% endif %} diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md b/libs/EXTERNAL/capnproto/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md index 345d756672c..52f11433387 100644 --- a/libs/EXTERNAL/capnproto/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md +++ b/libs/EXTERNAL/capnproto/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md @@ -13,4 +13,4 @@ Cap'n Proto 0.5.1 has just been released with some bug fixes: Sorry about the bugs. -In other news, as you can see, the Cap'n Proto web site now lives at `capnproto.org`. Additionally, the Github repo has been moved to the [Sandstorm.io organization](https://github.com/sandstorm-io). Both moves have left behind redirects so that old links / repository references should continue to work. +In other news, as you can see, the Cap'n Proto web site now lives at `capnproto.org`. Additionally, the Github repo has been moved to the [Sandstorm.io organization](https://github.com/capnproto). Both moves have left behind redirects so that old links / repository references should continue to work. diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md b/libs/EXTERNAL/capnproto/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md index 1834524d17b..eaa2b16158f 100644 --- a/libs/EXTERNAL/capnproto/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md +++ b/libs/EXTERNAL/capnproto/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md @@ -6,11 +6,11 @@ author: kentonv As the installation page has always stated, I do not yet recommend using Cap'n Proto's C++ library for handling possibly-malicious input, and will not recommend it until it undergoes a formal security review. That said, security is obviously a high priority for the project. The security of Cap'n Proto is in fact essential to the security of [Sandstorm.io](https://sandstorm.io), Cap'n Proto's parent project, in which sandboxed apps communicate with each other and the platform via Cap'n Proto RPC. -A few days ago, the first major security bugs were found in Cap'n Proto C++ -- two by security guru [Ben Laurie](http://en.wikipedia.org/wiki/Ben_Laurie) and one by myself during subsequent review (see below). You can read details about each bug in our new [security advisories directory](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories): +A few days ago, the first major security bugs were found in Cap'n Proto C++ -- two by security guru [Ben Laurie](http://en.wikipedia.org/wiki/Ben_Laurie) and one by myself during subsequent review (see below). You can read details about each bug in our new [security advisories directory](https://github.com/capnproto/capnproto/tree/master/security-advisories): -* [Integer overflow in pointer validation.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md) -* [Integer underflow in pointer validation.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md) -* [CPU usage amplification attack.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-2-all-cpu-amplification.md) +* [Integer overflow in pointer validation.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md) +* [Integer underflow in pointer validation.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md) +* [CPU usage amplification attack.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-2-all-cpu-amplification.md) I have backported the fixes to the last two release branches -- 0.5 and 0.4: @@ -109,15 +109,15 @@ So, a `Guarded<10, int>` represents a `int` which is statically guaranteed to ho Moreover, because all of `Guarded`'s operators are inline and `constexpr`, a good optimizing compiler will be able to optimize `Guarded` down to the underlying primitive integer type. So, in theory, using `Guarded` has no runtime overhead. (I have not yet verified that real compilers get this right, but I suspect they do.) -Of course, the full implementation is considerably more complicated than this. The code has not been merged into the Cap'n Proto tree yet as we need to do more analysis to make sure it has no negative impact. For now, you can find it in the [overflow-safe](https://github.com/sandstorm-io/capnproto/tree/overflow-safe) branch, specifically in the second half of [kj/units.h](https://github.com/sandstorm-io/capnproto/blob/overflow-safe/c++/src/kj/units.h). (This header also contains metaprogramming for compile-time unit analysis, which Cap'n Proto has been using since its first release.) +Of course, the full implementation is considerably more complicated than this. The code has not been merged into the Cap'n Proto tree yet as we need to do more analysis to make sure it has no negative impact. For now, you can find it in the [overflow-safe](https://github.com/capnproto/capnproto/tree/overflow-safe) branch, specifically in the second half of [kj/units.h](https://github.com/capnproto/capnproto/blob/overflow-safe/c++/src/kj/units.h). (This header also contains metaprogramming for compile-time unit analysis, which Cap'n Proto has been using since its first release.) ### Results I switched Cap'n Proto's core pointer validation code (`capnp/layout.c++`) over to `Guarded`. In the process, I found: * Several overflows that could be triggered by the application calling methods with invalid parameters, but not by a remote attacker providing invalid message data. We will change the code to check these in the future, but they are not critical security problems. -* The overflow that Ben had already reported ([2015-03-02-0](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md)). I had intentionally left this unfixed during my analysis to verify that `Guarded` would catch it. -* One otherwise-undiscovered integer underflow ([2015-03-02-1](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md)). +* The overflow that Ben had already reported ([2015-03-02-0](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md)). I had intentionally left this unfixed during my analysis to verify that `Guarded` would catch it. +* One otherwise-undiscovered integer underflow ([2015-03-02-1](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md)). Based on these results, I conclude that `Guarded` is in fact effective at finding overflow bugs, and that such bugs are thankfully _not_ endemic in Cap'n Proto's code. diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2015-03-05-another-cpu-amplification.md b/libs/EXTERNAL/capnproto/doc/_posts/2015-03-05-another-cpu-amplification.md index d074c9db800..4c8ae97591f 100644 --- a/libs/EXTERNAL/capnproto/doc/_posts/2015-03-05-another-cpu-amplification.md +++ b/libs/EXTERNAL/capnproto/doc/_posts/2015-03-05-another-cpu-amplification.md @@ -8,7 +8,7 @@ Unfortunately, it turns out that our fix for one of [the security advisories iss Fortunately, the incomplete fix is for the non-critical vulnerability. The worst case is that an attacker could consume excessive CPU time. -Nevertheless, we've issued [a new advisory](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md) and pushed a new release: +Nevertheless, we've issued [a new advisory](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md) and pushed a new release: - Release 0.5.1.2: [source](https://capnproto.org/capnproto-c++-0.5.1.2.tar.gz), [win32](https://capnproto.org/capnproto-c++-win32-0.5.1.2.zip) - Release 0.4.1.2: [source](https://capnproto.org/capnproto-c++-0.4.1.2.tar.gz) diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md b/libs/EXTERNAL/capnproto/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md index bd36cbbaced..336f0d838f6 100644 --- a/libs/EXTERNAL/capnproto/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md +++ b/libs/EXTERNAL/capnproto/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md @@ -40,7 +40,7 @@ The 0.6 release includes a number of measures designed to harden Cap'n Proto's C Cap'n Proto messages can now be converted to and from JSON using `libcapnp-json`. This makes it easy to integrate your JSON front-end API with your Cap'n Proto back-end. -See the capnp/compat/json.h header for API details. +See the capnp/compat/json.h header for API details. This library was primarily built by [**Kamal Marhubi**](https://github.com/kamalmarhubi) and [**Branislav Katreniak**](https://github.com/katreniak), using Cap'n Proto's [dynamic API]({{site.baseurl}}cxx.html#dynamic-reflection). @@ -48,7 +48,7 @@ This library was primarily built by [**Kamal Marhubi**](https://github.com/kamal KJ (the C++ framework library bundled with Cap'n Proto) now ships with a minimalist HTTP library, `libkj-http`. The library is based on the KJ asynchronous I/O framework and covers both client-side and server-side use cases. Although functional and used in production today, the library should be considered a work in progress -- expect improvements in future releases, such as client connection pooling and TLS support. -See the kj/compat/http.h header for API details. +See the kj/compat/http.h header for API details. #### Smaller things diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2022-06-03-capnproto-0.10.md b/libs/EXTERNAL/capnproto/doc/_posts/2022-06-03-capnproto-0.10.md new file mode 100644 index 00000000000..999fd10fbc7 --- /dev/null +++ b/libs/EXTERNAL/capnproto/doc/_posts/2022-06-03-capnproto-0.10.md @@ -0,0 +1,12 @@ +--- +layout: post +title: "Cap'n Proto 0.10" +author: kentonv +--- + + + +Today I'm releasing Cap'n Proto 0.10. + +Like last time, there's no huge new features in this release, but there are many minor improvements and bug fixes. You can [read the PR history](https://github.com/capnproto/capnproto/pulls?q=is%3Apr+is%3Aclosed) to find out what has changed. diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md b/libs/EXTERNAL/capnproto/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md new file mode 100644 index 00000000000..01284cca4aa --- /dev/null +++ b/libs/EXTERNAL/capnproto/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md @@ -0,0 +1,13 @@ +--- +layout: post +title: "CVE-2022-46149: Possible out-of-bounds read related to list-of-pointers" +author: kentonv +--- + +David Renshaw, the author of the Rust implementation of Cap'n Proto, discovered a security vulnerability affecting both the C++ and Rust implementations of Cap'n Proto. The vulnerability was discovered using fuzzing. In theory, the vulnerability could lead to out-of-bounds reads which could cause crashes or perhaps exfiltration of memory. + +The vulnerability is exploitable only if an application performs a certain unusual set of actions. As of this writing, we are not aware of any applications that are actually affected. However, out of an abundance of caution, we are issuing a security advisory and advising everyone to patch. + +[Our security advisory](https://github.com/capnproto/capnproto/blob/master/security-advisories/2022-11-30-0-pointer-list-bounds.md) explains the impact of the bug, what an app must do to be affected, and where to find the fix. + +Check out [David's blog post](https://dwrensha.github.io/capnproto-rust/2022/11/30/out_of_bounds_memory_access_bug.html) for an in-depth explanation of the bug itself, including some of the inner workings of Cap'n Proto. diff --git a/libs/EXTERNAL/capnproto/doc/_posts/2023-07-28-capnproto-1.0.md b/libs/EXTERNAL/capnproto/doc/_posts/2023-07-28-capnproto-1.0.md new file mode 100644 index 00000000000..94073efa3c8 --- /dev/null +++ b/libs/EXTERNAL/capnproto/doc/_posts/2023-07-28-capnproto-1.0.md @@ -0,0 +1,74 @@ +--- +layout: post +title: "Cap'n Proto 1.0" +author: kentonv +--- + + + +It's been a little over ten years since the first release of Cap'n Proto, on April 1, 2013. Today I'm releasing version 1.0 of Cap'n Proto's C++ reference implementation. + +Don't get too excited! There's not actually much new. Frankly, I should have declared 1.0 a long time ago – probably around version 0.6 (in 2017) or maybe even 0.5 (in 2014). I didn't mostly because there were a few advanced features (like three-party handoff, or shared-memory RPC) that I always felt like I wanted to finish before 1.0, but they just kept not reaching the top of my priority list. But the reality is that Cap'n Proto has been relied upon in production for a long time. In fact, you are using Cap'n Proto right now, to view this site, which is served by Cloudflare, which uses Cap'n Proto extensively (and is also my employer, although they used Cap'n Proto before they hired me). Cap'n Proto is used to encode millions (maybe billions) of messages and gigabits (maybe terabits) of data every single second of every day. As for those still-missing features, the real world has seemingly proven that they aren't actually that important. (I still do want to complete them though.) + +Ironically, the thing that finally motivated the 1.0 release is so that we can start working on 2.0. But again here, don't get too excited! Cap'n Proto 2.0 is not slated to be a revolutionary change. Rather, there are a number of changes we (the Cloudflare Workers team) would like to make to Cap'n Proto's C++ API, and its companion, the KJ C++ toolkit library. Over the ten years these libraries have been available, I have kept their APIs pretty stable, despite being 0.x versioned. But for 2.0, we want to make some sweeping backwards-incompatible changes, in order to fix some footguns and improve developer experience for those on our team. + +Some users probably won't want to keep up with these changes. Hence, I'm releasing 1.0 now as a sort of "long-term support" release. We'll backport bugfixes as appropriate to the 1.0 branch for the long term, so that people who aren't interested in changes can just stick with it. + +## What's actually new in 1.0? + +Again, not a whole lot has changed since the last version, 0.10. But there are a few things worth mentioning: + +* A number of optimizations were made to improve performance of Cap'n Proto RPC. These include reducing the amount of memory allocation done by the RPC implementation and KJ I/O framework, adding the ability to elide certain messages from the RPC protocol to reduce traffic, and doing better buffering of small messages that are sent and received together to reduce syscalls. These are incremental improvements. + +* **Breaking change:** Previously, servers could opt into allowing RPC cancellation by calling `context.allowCancellation()` after a call was delivered. In 1.0, opting into cancellation is instead accomplished using an annotation on the schema (the `allowCancellation` annotation defined in `c++.capnp`). We made this change after observing that in practice, we almost always wanted to allow cancellation, but we almost always forgot to do so. The schema-level annotation can be set on a whole file at a time, which is easier not to forget. Moreover, the dynamic opt-in required a lot of bookkeeping that had a noticeable performance impact in practice; switching to the annotation provided a performance boost. For users that never used `context.allowCancellation()` in the first place, there's no need to change anything when upgrading to 1.0 – cancellation is still disallowed by default. (If you are affected, you will see a compile error. If there's no compile error, you have nothing to worry about.) + +* KJ now uses `kqueue()` to handle asynchronous I/O on systems that have it (MacOS and BSD derivatives). KJ has historically always used `epoll` on Linux, but until now had used a slower `poll()`-based approach on other Unix-like platforms. + +* KJ's HTTP client and server implementations now support the `CONNECT` method. + +* [A new class `capnp::RevocableServer` was introduced](https://github.com/capnproto/capnproto/pull/1700) to assist in exporting RPC wrappers around objects whose lifetimes are not controlled by the wrapper. Previously, avoiding use-after-free bugs in such scenarios was tricky. + +* Many, many smaller bug fixes and improvements. [See the PR history](https://github.com/capnproto/capnproto/pulls?q=is%3Apr+is%3Aclosed) for details. + +## What's planned for 2.0? + +The changes we have in mind for version 2.0 of Cap'n Proto's C++ implementation are mostly NOT related to the protocol itself, but rather to the C++ API and especially to KJ, the C++ toolkit library that comes with Cap'n Proto. These changes are motivated by our experience building a large codebase on top of KJ: namely, the Cloudflare Workers runtime, [`workerd`](https://github.com/cloudflare/workerd). + +KJ is a C++ toolkit library, arguably comparable to things like Boost, Google's Abseil, or Facebook's Folly. I started building KJ at the same time as Cap'n Proto in 2013, at a time when C++11 was very new and most libraries were not really designing around it yet. The intent was never to create a new standard library, but rather to address specific needs I had at the time. But over many years, I ended up building a lot of stuff. By the time I joined Cloudflare and started the Workers Runtime, KJ already featured a powerful async I/O framework, HTTP implementation, TLS bindings, and more. + +Of course, KJ has nowhere near as much stuff as Boost or Abseil, and nowhere near as much engineering effort behind it. You might argue, therefore, that it would have been better to choose one of those libraries to build on. However, KJ had a huge advantage: that we own it, and can shape it to fit our specific needs, without having to fight with anyone to get those changes upstreamed. + +One example among many: KJ's HTTP implementation features the ability to "suspend" the state of an HTTP connection, after receiving headers, and transfer it to a different thread or process to be resumed. This is an unusual thing to want, but is something we needed for resource management in the Workers Runtime. Implementing this required some deep surgery in KJ HTTP and definitely adds complexity. If we had been using someone else's HTTP library, would they have let us upstream such a change? + +That said, even though we own KJ, we've still tried to avoid making any change that breaks third-party users, and this has held back some changes that would probably benefit Cloudflare Workers. We have therefore decided to "fork" it. Version 2.0 is that fork. + +Development of version 2.0 will take place on Cap'n Proto's new `v2` branch. The `master` branch will become the 1.0 LTS branch, so that existing projects which track `master` are not disrupted by our changes. + +We don't yet know all the changes we want to make as we've only just started thinking seriously about it. But, here's some ideas we've had so far: + +* We will require a compiler with support for C++20, or maybe even C++23. Cap'n Proto 1.0 only requires C++14. + +* In particular, we will require a compiler that supports C++20 coroutines, as lots of KJ async code will be refactored to rely on coroutines. This should both make the code clearer and improve performance by reducing memory allocations. However, coroutine support is still spotty – as of this writing, GCC seems to ICE on KJ's coroutine implementation. + +* Cap'n Proto's RPC API, KJ's HTTP APIs, and others are likely to be revised to make them more coroutine-friendly. + +* `kj::Maybe` will become more ergonomic. It will no longer overload `nullptr` to represent the absence of a value; we will introduce `kj::none` instead. `KJ_IF_MAYBE` will no longer produce a pointer, but instead a reference (a trick that becomes possible by utilizing C++17 features). + +* We will drop support for compiling with exceptions disabled. KJ's coding style uses exceptions as a form of software fault isolation, or "catchable panics", such that errors can cause the "current task" to fail out without disrupting other tasks running concurrently. In practice, this ends up affecting every part of how KJ-style code is written. And yet, since the beginning, KJ and Cap'n Proto have been designed to accommodate environments where exceptions are turned off at compile time, using an elaborate system to fall back to callbacks and distinguish between fatal and non-fatal exceptions. In practice, maintaining this ability has been a drag on development – no-exceptions mode is constantly broken and must be tediously fixed before each release. Even when the tests are passing, it's likely that a lot of KJ's functionality realistically cannot be used in no-exceptions mode due to bugs and fragility. Today, I would strongly recommend against anyone using this mode except maybe for the most basic use of Cap'n Proto's serialization layer. Meanwhile, though, I'm honestly not sure if anyone uses this mode at all! In theory I would expect many people do, since many people choose to use C++ with exceptions disabled, but I've never actually received a single question or bug report related to it. It seems very likely that this was wasted effort all along. By removing support, we can simplify a lot of stuff and probably do releases more frequently going forward. + +* Similarly, we'll drop support for no-RTTI mode and other exotic modes that are a maintenance burden. + +* We may revise KJ's approach to reference counting, as the current design has proven to be unintuitive to many users. + +* We will fix a longstanding design flaw in `kj::AsyncOutputStream`, where EOF is currently signaled by destroying the stream. Instead, we'll add an explicit `end()` method that returns a Promise. Destroying the stream without calling `end()` will signal an erroneous disconnect. (There are several other aesthetic improvements I'd like to make to the KJ stream APIs as well.) + +* We may want to redesign several core I/O APIs to be a better fit for Linux's new-ish io_uring event notification paradigm. + +* The RPC implementation may switch to allowing cancellation by default. As discussed above, this is opt-in today, but in practice I find it's almost always desirable, and disallowing it can lead to subtle problems. + +* And so on. + +It's worth noting that at present, there is no plan to make any backwards-incompatible changes to the serialization format or RPC protocol. The changes being discussed only affect the C++ API. Applications written in other languages are completely unaffected by all this. + +It's likely that a formal 2.0 release will not happen for some time – probably a few years. I want to make sure we get through all the really big breaking changes we want to make, before we inflict update pain on most users. Of course, if you're willing to accept breakages, you can always track the `v2` branch. Cloudflare Workers releases from `v2` twice a week, so it should always be in good working order. diff --git a/libs/EXTERNAL/capnproto/doc/cxx.md b/libs/EXTERNAL/capnproto/doc/cxx.md index fd0ebe8670c..dcd8b4cca90 100644 --- a/libs/EXTERNAL/capnproto/doc/cxx.md +++ b/libs/EXTERNAL/capnproto/doc/cxx.md @@ -160,7 +160,7 @@ See the header `kj/exception.h` for details on how to register an exception call Cap'n Proto is built on top of a basic utility library called KJ. The two were actually developed together -- KJ is simply the stuff which is not specific to Cap'n Proto serialization, and may be -useful to others independently of Cap'n Proto. For now, the the two are distributed together. The +useful to others independently of Cap'n Proto. For now, the two are distributed together. The name "KJ" has no particular meaning; it was chosen to be short and easy-to-type. As of v0.3, KJ is distributed with Cap'n Proto but built as a separate library. You may need @@ -179,7 +179,7 @@ To use this code in your app, you must link against both `libcapnp` and `libkj`. flags. If you use [RPC](cxxrpc.html) (i.e., your schema defines [interfaces](language.html#interfaces)), -then you will additionally nead to link against `libcapnp-rpc` and `libkj-async`, or use the +then you will additionally need to link against `libcapnp-rpc` and `libkj-async`, or use the `capnp-rpc` `pkg-config` module. ### Setting a Namespace @@ -794,7 +794,7 @@ Here are some tips for using the C++ Cap'n Proto runtime most effectively: dead space. In the future, Cap'n Proto may be improved such that it can re-use dead space in a message. - However, this will only improve things, not fix them entirely: fragementation could still leave + However, this will only improve things, not fix them entirely: fragmentation could still leave dead space. ### Build Tips @@ -877,7 +877,7 @@ tips will apply. ## Lessons Learned from Protocol Buffers -The author of Cap'n Proto's C++ implementation also wrote (in the past) verison 2 of Google's +The author of Cap'n Proto's C++ implementation also wrote (in the past) version 2 of Google's Protocol Buffers. As a result, Cap'n Proto's implementation benefits from a number of lessons learned the hard way: diff --git a/libs/EXTERNAL/capnproto/doc/cxxrpc.md b/libs/EXTERNAL/capnproto/doc/cxxrpc.md index eb25acb8b50..3e55bcade5e 100644 --- a/libs/EXTERNAL/capnproto/doc/cxxrpc.md +++ b/libs/EXTERNAL/capnproto/doc/cxxrpc.md @@ -16,7 +16,7 @@ not yet implemented. ## Sample Code -The [Calculator example](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples) implements +The [Calculator example](https://github.com/capnproto/capnproto/tree/master/c++/samples) implements a fully-functional Cap'n Proto client and server. ## KJ Concurrency Framework @@ -394,7 +394,7 @@ addresses. Additionally, a Unix domain socket can be specified as `unix:` follo and an abstract Unix domain socket can be specified as `unix-abstract:` followed by an identifier. For a more complete example, see the -[calculator client sample](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples/calculator-client.c++). +[calculator client sample](https://github.com/capnproto/capnproto/tree/master/c++/samples/calculator-client.c++). ### Starting a server @@ -434,7 +434,7 @@ path name, and an abstract Unix domain socket can be specified as `unix-abstract an identifier. For a more complete example, see the -[calculator server sample](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples/calculator-server.c++). +[calculator server sample](https://github.com/capnproto/capnproto/tree/master/c++/samples/calculator-server.c++). ## Debugging diff --git a/libs/EXTERNAL/capnproto/doc/encoding.md b/libs/EXTERNAL/capnproto/doc/encoding.md index 5eeba8fa31b..78a203249f1 100644 --- a/libs/EXTERNAL/capnproto/doc/encoding.md +++ b/libs/EXTERNAL/capnproto/doc/encoding.md @@ -268,7 +268,7 @@ A capability pointer, then, simply contains an index into the separate capabilit C (32 bits) = Index of the capability in the message's capability table. -In [rpc.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c++/src/capnp/rpc.capnp), the +In [rpc.capnp](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/rpc.capnp), the capability table is encoded as a list of `CapDescriptors`, appearing along-side the message content in the `Payload` struct. However, some use cases may call for different approaches. A message that is built and consumed within the same process need not encode the capability table at all @@ -415,7 +415,7 @@ different limit if desired. Another reasonable strategy is to set the limit to s the original message size; however, most applications should place limits on overall message sizes anyway, so it makes sense to have one check cover both. -**List amplification:** A list of `Void` values or zero-size structs can have a very large element count while taking constant space on the wire. If the receiving application expects a list of structs, it will see these zero-sized elements as valid structs set to their default values. If it iterates through the list processing each element, it could spend a large amount of CPU time or other resources despite the message being small. To defend against this, the "traversal limit" should count a list of zero-sized elements as if each element were one word instead. This rule was introduced in the C++ implementation in [commit 1048706](https://github.com/sandstorm-io/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d). +**List amplification:** A list of `Void` values or zero-size structs can have a very large element count while taking constant space on the wire. If the receiving application expects a list of structs, it will see these zero-sized elements as valid structs set to their default values. If it iterates through the list processing each element, it could spend a large amount of CPU time or other resources despite the message being small. To defend against this, the "traversal limit" should count a list of zero-sized elements as if each element were one word instead. This rule was introduced in the C++ implementation in [commit 1048706](https://github.com/capnproto/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d). ### Stack overflow DoS attack diff --git a/libs/EXTERNAL/capnproto/doc/faq.md b/libs/EXTERNAL/capnproto/doc/faq.md index e3bf4becb55..9443e126838 100644 --- a/libs/EXTERNAL/capnproto/doc/faq.md +++ b/libs/EXTERNAL/capnproto/doc/faq.md @@ -197,15 +197,26 @@ Cap'n Proto may be layered on top of an existing encrypted transport, such as TL ### How do I report security bugs? -Please email [security@sandstorm.io](mailto:security@sandstorm.io). +Please email [kenton@cloudflare.com](mailto:kenton@cloudflare.com). ## Sandstorm ### How does Cap'n Proto relate to Sandstorm.io? -[Sandstorm.io](https://sandstorm.io) is an Open Source project and startup founded by Kenton, the author of Cap'n Proto. Cap'n Proto is owned and developed by Sandstorm the company and heavily used in Sandstorm the project. +[Sandstorm.io](https://sandstorm.io) is an Open Source project and startup founded by Kenton, the author of Cap'n Proto. Cap'n Proto was developed by Sandstorm the company and heavily used in Sandstorm the project. Sandstorm ceased most operations in 2017 and formally dissolved as a company in 2022, but the open source project continues to be developed by the community. ### How does Sandstorm use Cap'n Proto? See [this Sandstorm blog post](https://blog.sandstorm.io/news/2014-12-15-capnproto-0.5.html). +## Cloudflare + +### How does Cap'n Proto relate to Cloudflare? + +[Cloudflare Workers](https://workers.dev) is a next-generation cloud application platform. Kenton, the author of Cap'n Proto, is the lead engineer on the Workers project. Workers heavily uses Cap'n Proto in its implementation, and the Cloudflare Workers team are now the primarily developers and maintainers of Cap'n Proto's primary C++ implementation. + +### How does Cloudflare use Cap'n Proto? + +The Cloudflare Workers runtime is built on Cap'n Proto and it's associated C++ toolkit library, KJ. Cap'n Proto is used for a variety of things, such as communication between sandbox processes and their supervisors, as well between machines and datacenters, especially in the implementation of [Durable Objects](https://blog.cloudflare.com/introducing-workers-durable-objects/). + +Cloudflare has also [long used Cap'n Proto in its logging pipeline](http://www.thedotpost.com/2015/06/john-graham-cumming-i-got-10-trillion-problems-but-logging-aint-one) and [developed the Lua implementation of Cap'n Proto](https://blog.cloudflare.com/introducing-lua-capnproto-better-serialization-in-lua/) -- both of these actually predate Kenton joining the company. diff --git a/libs/EXTERNAL/capnproto/doc/index.md b/libs/EXTERNAL/capnproto/doc/index.md index 6545ebafa83..ae57b839792 100644 --- a/libs/EXTERNAL/capnproto/doc/index.md +++ b/libs/EXTERNAL/capnproto/doc/index.md @@ -51,7 +51,7 @@ Cap'n Proto generates classes with accessor methods that you use to traverse the Thus, Cap'n Proto checks the structural integrity of the message just like any other serialization protocol would. And, just like any other protocol, it is up to the app to check the validity of the content. -Cap'n Proto was built to be used in [Sandstorm.io](https://sandstorm.io), where security is a major concern. As of this writing, Cap'n Proto has not undergone a security review, therefore we suggest caution when handling messages from untrusted sources. That said, our response to security issues was once described by security guru Ben Laurie as ["the most awesome response I've ever had."](https://twitter.com/BenLaurie/status/575079375307153409) (Please report all security issues to [security@sandstorm.io](mailto:security@sandstorm.io).) +Cap'n Proto was built to be used in [Sandstorm.io](https://sandstorm.io), and is now heavily used in [Cloudflare Workers](https://workers.dev), two environments where security is a major concern. Cap'n Proto has undergone fuzzing and expert security review. Our response to security issues was once described by security guru Ben Laurie as ["the most awesome response I've ever had."](https://twitter.com/BenLaurie/status/575079375307153409) (Please report all security issues to [kenton@cloudflare.com](mailto:kenton@cloudflare.com).) **_Are there other advantages?_** @@ -90,7 +90,7 @@ version 2, which is the version that Google released open source. Cap'n Proto is years of experience working on Protobufs, listening to user feedback, and thinking about how things could be done better. -Note that I no longer work for Google. Cap'n Proto is not, and never has been, affiliated with Google; in fact, it is a property of [Sandstorm.io](https://sandstorm.io), of which I am co-founder. +Note that I no longer work for Google. Cap'n Proto is not, and never has been, affiliated with Google. **_OK, how do I get started?_** diff --git a/libs/EXTERNAL/capnproto/doc/install.md b/libs/EXTERNAL/capnproto/doc/install.md index ca94e4219ca..d911ee15695 100644 --- a/libs/EXTERNAL/capnproto/doc/install.md +++ b/libs/EXTERNAL/capnproto/doc/install.md @@ -21,13 +21,13 @@ This package is licensed under the [MIT License](http://opensource.org/licenses/ Cap'n Proto makes extensive use of C++14 language features. As a result, it requires a relatively new version of a well-supported compiler. The minimum versions are: -* GCC 5.0 -* Clang 5.0 -* Visual C++ 2017 +* GCC 7.0 +* Clang 6.0 +* Visual C++ 2019 If your system's default compiler is older that the above, you will need to install a newer compiler and set the `CXX` environment variable before trying to build Cap'n Proto. For example, -after installing GCC 5, you could set `CXX=g++-5` to use this compiler. +after installing GCC 7, you could set `CXX=g++-7` to use this compiler. ### Supported Operating Systems @@ -37,11 +37,10 @@ as well as on Windows. We test every Cap'n Proto release on the following platfo * Android * Linux * Mac OS X -* Windows - Cygwin * Windows - MinGW-w64 * Windows - Visual C++ -**Windows users:** Cap'n Proto requires Visual Studio 2017 or newer. All features +**Windows users:** Cap'n Proto requires Visual Studio 2019 or newer. All features of Cap'n Proto -- including serialization, dynamic API, RPC, and schema parser -- are now supported. **Mac OS X users:** You should use the latest Xcode with the Xcode command-line @@ -56,9 +55,9 @@ package from [Apple](https://developer.apple.com/downloads/) or compiler builds You may download and install the release version of Cap'n Proto like so: -
curl -O https://capnproto.org/capnproto-c++-0.9.1.tar.gz
-tar zxf capnproto-c++-0.9.1.tar.gz
-cd capnproto-c++-0.9.1
+
curl -O https://capnproto.org/capnproto-c++-1.0.2.tar.gz
+tar zxf capnproto-c++-1.0.2.tar.gz
+cd capnproto-c++-1.0.2
 ./configure
 make -j6 check
 sudo make install
@@ -84,7 +83,7 @@ If you download directly from Git, you will need to have the GNU autotools -- [automake](http://www.gnu.org/software/automake/), and [libtool](http://www.gnu.org/software/libtool/) -- installed. - git clone https://github.com/sandstorm-io/capnproto.git + git clone -b master https://github.com/capnproto/capnproto.git cd capnproto/c++ autoreconf -i ./configure @@ -97,15 +96,15 @@ If you download directly from Git, you will need to have the GNU autotools -- 1. Download Cap'n Proto Win32 build: -
https://capnproto.org/capnproto-c++-win32-0.9.1.zip
+
https://capnproto.org/capnproto-c++-win32-1.0.2.zip
-2. Find `capnp.exe`, `capnpc-c++.exe`, and `capnpc-capnp.exe` under `capnproto-tools-win32-0.9.1` in +2. Find `capnp.exe`, `capnpc-c++.exe`, and `capnpc-capnp.exe` under `capnproto-tools-win32-1.0.2` in the zip and copy them somewhere. 3. If your `.capnp` files will import any of the `.capnp` files provided by the core project, or if you use the `stream` keyword (which implicitly imports `capnp/stream.capnp`), then you need to put those files somewhere where the capnp compiler can find them. To do this, copy the - directory `capnproto-c++-0.9.1/src` to the location of your choice, then make sure to pass the + directory `capnproto-c++-1.0.2/src` to the location of your choice, then make sure to pass the flag `-I ` to `capnp` when you run it. If you don't care about C++ support, you can stop here. The compiler exe can be used with plugins @@ -113,16 +112,16 @@ provided by projects implementing Cap'n Proto in other languages. If you want to use Cap'n Proto in C++ with Visual Studio, do the following: -1. Make sure that you are using Visual Studio 2017 or newer, with all updates installed. Cap'n +1. Make sure that you are using Visual Studio 2019 or newer, with all updates installed. Cap'n Proto uses C++14 language features that did not work in previous versions of Visual Studio, and the updates include many bug fixes that Cap'n Proto requires. -2. Install [CMake](http://www.cmake.org/) version 3.1 or later. +2. Install [CMake](http://www.cmake.org/) version 3.16 or later. -3. Use CMake to generate Visual Studio project files under `capnproto-c++-0.9.1` in the zip file. +3. Use CMake to generate Visual Studio project files under `capnproto-c++-1.0.2` in the zip file. You can use the CMake UI for this or run this shell command: - cmake -G "Visual Studio 15 2017" + cmake -G "Visual Studio 16 2019" 3. Open the "Cap'n Proto" solution in Visual Studio. diff --git a/libs/EXTERNAL/capnproto/doc/language.md b/libs/EXTERNAL/capnproto/doc/language.md index 034e854c460..5b638ee840f 100644 --- a/libs/EXTERNAL/capnproto/doc/language.md +++ b/libs/EXTERNAL/capnproto/doc/language.md @@ -581,7 +581,9 @@ struct Foo { The above imports specify relative paths. If the path begins with a `/`, it is absolute -- in this case, the `capnp` tool searches for the file in each of the search path directories specified -with `-I`. +with `-I`, appending the path you specify to the path given to the `-I` flag. So, for example, +if you ran `capnp` with `-Ifoo/bar`, and the import statement is `import "/baz/qux.capnp"`, then +the compiler would open the file `foo/bar/baz/qux.capnp`. ### Annotations @@ -606,7 +608,7 @@ struct MyType $foo("bar") { {% endhighlight %} The possible targets for an annotation are: `file`, `struct`, `field`, `union`, `group`, `enum`, -`enumerant`, `interface`, `method`, `parameter`, `annotation`, `const`. +`enumerant`, `interface`, `method`, `param`, `annotation`, `const`. You may also specify `*` to cover them all. {% highlight capnp %} @@ -736,7 +738,15 @@ without changing the [canonical](encoding.html#canonicalization) encoding of a m * A field can be moved into a group or a union, as long as the group/union and all other fields within it are new. In other words, a field can be replaced with a group or union containing an - equivalent field and some new fields. + equivalent field and some new fields. Note that when creating a union this way, this particular + change is not fully forwards-compatible: if you create a message where one of the union's new + fields are set, and the message is read by an old program that dosen't know about the union, then + it may expect the original field to be present, and if it tries to read that field, may see a + garbage value or throw an exception. To avoid this problem, make sure to only use the new union + members when talking to programs that know about the union. This caveat only applies when moving + an existing field into a new union; adding new fields to an existing union does not create a + problem, because existing programs should already know to check the union's tag (although they + may or may not behave reasonably when the tag has a value they don't recognize). * A non-generic type can be made [generic](#generic-types), and new generic parameters may be added to an existing generic type. Other types used inside the body of the newly-generic type can diff --git a/libs/EXTERNAL/capnproto/doc/otherlang.md b/libs/EXTERNAL/capnproto/doc/otherlang.md index e3064927679..c7c156783b8 100644 --- a/libs/EXTERNAL/capnproto/doc/otherlang.md +++ b/libs/EXTERNAL/capnproto/doc/otherlang.md @@ -14,9 +14,9 @@ project's documentation for details. ##### Serialization + RPC * [C++](cxx.html) by [@kentonv](https://github.com/kentonv) -* [C# (.NET Core)](https://github.com/c80k/capnproto-dotnetcore) by [@c80k](https://github.com/c80k) +* [C#](https://github.com/c80k/capnproto-dotnetcore) by [@c80k](https://github.com/c80k) * [Erlang](http://ecapnp.astekk.se/) by [@kaos](https://github.com/kaos) -* [Go](https://github.com/zombiezen/go-capnproto2) by [@zombiezen](https://github.com/zombiezen) (forked from [@glycerine](https://github.com/glycerine)'s serialization-only version, below) +* [Go](https://github.com/capnproto/go-capnp) currently maintained by [@zenhack](https://github.com/zenhack) and [@lthibault](https://github.com/lthibault) * [Haskell](https://github.com/zenhack/haskell-capnp) by [@zenhack](https://github.com/zenhack) * [JavaScript (Node.js only)](https://github.com/capnproto/node-capnp) by [@kentonv](https://github.com/kentonv) * [OCaml](https://github.com/capnproto/capnp-ocaml) by [@pelzlpj](https://github.com/pelzlpj) with [RPC](https://github.com/mirage/capnp-rpc) by [@talex5](https://github.com/talex5) @@ -25,9 +25,9 @@ project's documentation for details. ##### Serialization only -* [C](https://github.com/opensourcerouting/c-capnproto) by [OpenSourceRouting](https://www.opensourcerouting.org/) / [@eqvinox](https://github.com/eqvinox) (originally by [@jmckaskill](https://github.com/jmckaskill)) +* [C](https://github.com/opensourcerouting/c-capnproto) by [OpenSourceRouting](https://www.opensourcerouting.org/) / [@eqvinox](https://github.com/eqvinox) (originally by [@jmckaskill](https://github.com/jmckaskill)) (no longer maintained) + * [Forked and maintained](https://gitlab.com/dkml/ext/c-capnproto) by [@jonahbeckford](https://github.com/jonahbeckford) * [D](https://github.com/capnproto/capnproto-dlang) by [@ThomasBrixLarsen](https://github.com/ThomasBrixLarsen) -* [Go](https://github.com/glycerine/go-capnproto) by [@glycerine](https://github.com/glycerine) (originally by [@jmckaskill](https://github.com/jmckaskill)) * [Java](https://github.com/capnproto/capnproto-java/) by [@dwrensha](https://github.com/dwrensha) * [JavaScript](https://github.com/capnp-js/plugin/) by [@popham](https://github.com/popham) * [JavaScript](https://github.com/jscheid/capnproto-js) (older, abandoned) by [@jscheid](https://github.com/jscheid) @@ -72,7 +72,7 @@ then hands the parse tree off to another binary -- known as a "plugin" -- which Plugins are independent executables (written in any language) which read a description of the schema from standard input and then generate the necessary code. The description is itself a Cap'n Proto message, defined by -[schema.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/schema.capnp). +[schema.capnp](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/schema.capnp). Specifically, the plugin receives a `CodeGeneratorRequest`, using [standard serialization](encoding.html#serialization-over-a-stream) (not packed). (Note that installing the C++ runtime causes schema.capnp to be placed in @@ -100,8 +100,8 @@ If the user specifies an output directory, the compiler will run the plugin with as the working directory, so you do not need to worry about this. For examples of plugins, take a look at -[capnpc-capnp](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-capnp.c%2B%2B) -or [capnpc-c++](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-c%2B%2B.c%2B%2B). +[capnpc-capnp](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-capnp.c%2B%2B) +or [capnpc-c++](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-c%2B%2B.c%2B%2B). ### Supporting Dynamic Languages diff --git a/libs/EXTERNAL/capnproto/doc/push-site.sh b/libs/EXTERNAL/capnproto/doc/push-site.sh index 958c1013dd8..48f89e62c01 100755 --- a/libs/EXTERNAL/capnproto/doc/push-site.sh +++ b/libs/EXTERNAL/capnproto/doc/push-site.sh @@ -9,7 +9,7 @@ if grep 'localhost:4000' *.md _posts/*.md; then fi if [ "x$(git status --porcelain)" != "x" ]; then - echo -n "git repo has uncommited changes. Continue anyway? (y/N) " >&2 + echo -n "git repo has uncommitted changes. Continue anyway? (y/N) " >&2 read -n 1 YESNO echo >&2 if [ "x$YESNO" != xy ]; then diff --git a/libs/EXTERNAL/capnproto/doc/roadmap.md b/libs/EXTERNAL/capnproto/doc/roadmap.md index 097828609e5..6b7fcc1e4d1 100644 --- a/libs/EXTERNAL/capnproto/doc/roadmap.md +++ b/libs/EXTERNAL/capnproto/doc/roadmap.md @@ -50,7 +50,7 @@ these will actually happen; as always, real work is driven by real-world needs. to each struct type. The POCS type would use traditional memory allocation, thus would not support zero-copy, but would support a more traditional and easy-to-use C++ API, including the ability to mutate the object over time without convoluted memory management. POCS types - could be extracted from an inserted into messages with a single copy, allowing them to be + could be extracted from and inserted into messages with a single copy, allowing them to be used easily in non-performance-critical code. * **Multi-threading:** It should be made easy to assign different Cap'n Proto RPC objects to different threads and have them be able to safely call each other. Each thread would still diff --git a/libs/EXTERNAL/capnproto/doc/rpc.md b/libs/EXTERNAL/capnproto/doc/rpc.md index 1c6dcd81fb4..9ef2a49fe4c 100644 --- a/libs/EXTERNAL/capnproto/doc/rpc.md +++ b/libs/EXTERNAL/capnproto/doc/rpc.md @@ -142,7 +142,7 @@ performs as well as we can possibly hope for. #### Example code -[The calculator example](https://github.com/sandstorm-io/capnproto/blob/master/c++/samples/calculator-client.c++) +[The calculator example](https://github.com/capnproto/capnproto/blob/master/c++/samples/calculator-client.c++) uses promise pipelining. Take a look at the client side in particular. ### Distributed Objects @@ -244,7 +244,7 @@ stream protocol, it can easily be layered on top of SSL/TLS or other such protoc The Cap'n Proto RPC protocol is defined in terms of Cap'n Proto serialization schemas. The documentation is inline. See -[rpc.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c++/src/capnp/rpc.capnp). +[rpc.capnp](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/rpc.capnp). Cap'n Proto's RPC protocol is based heavily on [CapTP](http://www.erights.org/elib/distrib/captp/index.html), the distributed capability protocol diff --git a/libs/EXTERNAL/capnproto/doc/slides-2017.05.18/index.md b/libs/EXTERNAL/capnproto/doc/slides-2017.05.18/index.md index a0e09373313..ad7ec6ebb68 100644 --- a/libs/EXTERNAL/capnproto/doc/slides-2017.05.18/index.md +++ b/libs/EXTERNAL/capnproto/doc/slides-2017.05.18/index.md @@ -531,7 +531,7 @@ getAll @3 (page :UInt32 = 0 $httpQuery) $http(method = get); # GET /?page= # Query is optional. -# JSAN (JSON array) repsonse body. +# JSAN (JSON array) response body. {% endhighlight %} diff --git a/libs/EXTERNAL/capnproto/highlighting/qtcreator/capnp.xml b/libs/EXTERNAL/capnproto/highlighting/qtcreator/capnp.xml index bb5c7812c4e..e675e5b0632 100644 --- a/libs/EXTERNAL/capnproto/highlighting/qtcreator/capnp.xml +++ b/libs/EXTERNAL/capnproto/highlighting/qtcreator/capnp.xml @@ -112,7 +112,7 @@ of these, like "keyword" and "type", could be mapped to dsKeyword and dsDataType, but there's a chance the user has mapped the colors for those things to things that would conflict with the manually-defined colors here, which would probably be even more annoying - than having the colors be inconsitent from other languages. So, I use manual colors for + than having the colors be inconsistent from other languages. So, I use manual colors for everything, except comments, which I figure are less likely to have this problem. --> diff --git a/libs/EXTERNAL/capnproto/kjdoc/tour.md b/libs/EXTERNAL/capnproto/kjdoc/tour.md index 06d9be08621..922e76941f2 100644 --- a/libs/EXTERNAL/capnproto/kjdoc/tour.md +++ b/libs/EXTERNAL/capnproto/kjdoc/tour.md @@ -97,7 +97,7 @@ for (auto i: kj::indices(foo)) { `kj::downcast(value)` is equivalent to `static_cast(value)`, except that when compiled in debug mode with RTTI available, a runtime check (`dynamic_cast`) will be performed to verify that `value` really has type `T`. Use this in cases where you are casting a base type to a derived type, and you are confident that the object is actually an instance of the derived type. The debug-mode check will help you catch bugs. -`kj::dynamicDowncastIfAvailable(value)` is like `dynamic_cast(value)` with two differences. First, it returns `kj::Maybe` instead of `T*`. Second, if the program is compiled without RTTI enabled, the function always returns null. This function is intended to be used to implement optimizations, where the code can do something smarter if `value` happens to be of some specific type -- but if RTTI is not available, it is safe to skip the optimization. See [KJ idiomatic use of dynamic_cast](style-guide.md#dynamic_cast) for more background. +`kj::dynamicDowncastIfAvailable(value)` is like `dynamic_cast(value)` with two differences. First, it returns `kj::Maybe` instead of `T*`. Second, if the program is compiled without RTTI enabled, the function always returns null. This function is intended to be used to implement optimizations, where the code can do something smarter if `value` happens to be of some specific type -- but if RTTI is not available, it is safe to skip the optimization. See [KJ idiomatic use of dynamic_cast](../style-guide.md#dynamic_cast) for more background. ### Min/max, numeric limits, and special floats @@ -120,7 +120,7 @@ These functions should almost never be used in high-level code. They are intende ## Ownership and memory management -KJ style makes heavy use of [RAII](style-guide.md#raii-resource-acquisition-is-initialization). KJ-based code should never use `new` and `delete` directly. Instead, use the utilities in this section to manage memory in a RAII way. +KJ style makes heavy use of [RAII](../style-guide.md#raii-resource-acquisition-is-initialization). KJ-based code should never use `new` and `delete` directly. Instead, use the utilities in this section to manage memory in a RAII way. ### Owned pointers, heap allocation, and disposers @@ -135,7 +135,7 @@ However, a `kj::Own` does not necessarily refer to a heap object. A `kj::Own` is Some example uses of disposers include: * `kj::fakeOwn(ref)` returns a `kj::Own` that points to `ref` but doesn't actually destroy it. This is useful when you know for sure that `ref` will outlive the scope of the `kj::Own`, and therefore heap allocation is unnecessary. This is common in cases where, for example, the `kj::Own` is being passed into an object which itself will be destroyed before `ref` becomes invalid. It also makes sense when `ref` is actually a static value or global that lives forever. -* `kj::refcounted(args...)` allocates a `T` which uses reference counting. It returns a `kj::Own` that represents one reference to the object. Additional references can be created by calling `kj::addRef(*ptr)`. The object is destroyed when no more `kj::Own`s exist pointing at it. Note that `T` must be a subclass of `kj::Refcounted`. If references may be shared across threads, then atomic refcounting must be used; use `kj::atomicRefcounted(args...)` and inherit `kj::AtomicRefcounted`. Reference counting should be using sparingly; see [KJ idioms around reference counting](style-guide.md#reference-counting) for a discussion of when it should be used and why it is designed the way it is. +* `kj::refcounted(args...)` allocates a `T` which uses reference counting. It returns a `kj::Own` that represents one reference to the object. Additional references can be created by calling `kj::addRef(*ptr)`. The object is destroyed when no more `kj::Own`s exist pointing at it. Note that `T` must be a subclass of `kj::Refcounted`. If references may be shared across threads, then atomic refcounting must be used; use `kj::atomicRefcounted(args...)` and inherit `kj::AtomicRefcounted`. Reference counting should be using sparingly; see [KJ idioms around reference counting](../style-guide.md#reference-counting) for a discussion of when it should be used and why it is designed the way it is. * `kj::attachRef(ref, args...)` returns a `kj::Own` pointing to `ref` that actually owns `args...`, so that when the `kj::Own` goes out-of-scope, the other arguments are destroyed. Typically these arguments are themselves `kj::Own`s or other pass-by-move values that themselves own the object referenced by `ref`. `kj::attachVal(value, args...)` is similar, where `value` is a pass-by-move value rather than a reference; a copy of it will be allocated on the heap. Finally, `ownPtr.attach(args...)` returns a new `kj::Own` pointing to the same value that `ownPtr` pointed to, but such that `args...` are owned as well and will be destroyed together. Attachments are always destroyed after the thing they are attached to. * `kj::SpaceFor` contains enough space for a value of type `T`, but does not construct the value until its `construct(args...)` method is called. That method returns an `kj::Own`, whose disposer destroys the value. `kj::SpaceFor` is thus a safer way to perform manual construction compared to invoking `kj::ctor()` and `kj::dtor()`. @@ -236,6 +236,8 @@ KJ_IF_MAYBE(j, maybeJ) { Note that `KJ_IF_MAYBE` forces you to think about the null case. This differs from `std::optional`, which can be dereferenced using `*`, resulting in undefined behavior if the value is null. +Similarly, `map()` and `orDefault()` allow transforming and retrieving the stored value in a safe manner without complex control flows. + Performance nuts will be interested to know that `kj::Maybe` and `kj::Maybe>` are both optimized such that they take no more space than their underlying pointer type, using a literal null pointer to indicate nullness. For other types of `T`, `kj::Maybe` must maintain an extra boolean and so is somewhat larger than `T`. ### Variant types @@ -278,7 +280,7 @@ typedef kj::OneOf State; `kj::Function` represents a callable function with the given signature. A `kj::Function` can be initialized from any callable object, such as a lambda, function pointer, or anything with `operator()`. `kj::Function` is useful when you want to write an API that accepts a lambda callback, without defining the API itself as a template. `kj::Function` supports move semantics. -`kj::ConstFunction` is like `kj::Function`, but is used to indicate that the function should be safe to call from multiple threads. (See [KJ idioms around constness and thread-safety](style-guide.md#constness).) +`kj::ConstFunction` is like `kj::Function`, but is used to indicate that the function should be safe to call from multiple threads. (See [KJ idioms around constness and thread-safety](../style-guide.md#constness).) A special optimization type, `kj::FunctionParam`, is like `kj::Function` but designed to be used specifically as the type of a callback parameter to some other function where that callback is only called synchronously; i.e., the callback won't be called anymore after the outer function returns. Unlike `kj::Function`, a `kj::FunctionParam` can be constructed entirely on the stack, with no heap allocation. @@ -300,7 +302,7 @@ KJ's tree-based containers use a b-tree design for better memory locality than t ## Debugging and Observability -KJ believes that there is no such thing as bug-free code. Instead, we must expect that our code will go wrong, and try to extract as much information as possible when it does. To that end, KJ provides powerful assertion macros designed for observability. (Be sure also to read about [KJ's exception philosophy](style-guide.md#exceptions); this section describes the actual APIs involved.) +KJ believes that there is no such thing as bug-free code. Instead, we must expect that our code will go wrong, and try to extract as much information as possible when it does. To that end, KJ provides powerful assertion macros designed for observability. (Be sure also to read about [KJ's exception philosophy](../style-guide.md#exceptions); this section describes the actual APIs involved.) ### Assertions @@ -400,7 +402,7 @@ On Windows, two similar macros are available based on Windows API calling conven ### Alternate exception types -As described in [KJ's exception philosophy](style-guide.md#exceptions), KJ supports a small set of exception types. Regular assertions throw `FAILED` exceptions. `KJ_SYSCALL` usually throws `FAILED`, but identifies certain error codes as `DISCONNECTED` or `OVERLOADED`. For example, `ECONNRESET` is clearly a `DISCONNECTED` exception. +As described in [KJ's exception philosophy](../style-guide.md#exceptions), KJ supports a small set of exception types. Regular assertions throw `FAILED` exceptions. `KJ_SYSCALL` usually throws `FAILED`, but identifies certain error codes as `DISCONNECTED` or `OVERLOADED`. For example, `ECONNRESET` is clearly a `DISCONNECTED` exception. If you wish to manually construct and throw a different exception type, you may use `KJ_EXCEPTION`: @@ -526,7 +528,7 @@ This section describes KJ APIs that control process execution and low-level inte `kj::Thread` creates a thread in which the lambda passed to `kj::Thread`'s constructor will be executed. `kj::Thread`'s destructor waits for the thread to exit before continuing, and rethrows any exception that had been thrown from the thread's main function -- unless the thread's `.detach()` method has been called, in which case `kj::Thread`'s destructor does nothing. -`kj::MutexGuarded` holds an instance of `T` that is protected by a mutex. In order to access the protected value, you must first create a lock. `.lockExclusive()` returns `kj::Locked` which can be used to access the underlying value. `.lockShared()` returns `kj::Locked`, [using constness to enforce thread-safe read-only access](style-guide.md#constness) so that multiple threads can take the lock concurrently. In this way, KJ mutexes make it difficult to forget to take a lock before accessing the protected object. +`kj::MutexGuarded` holds an instance of `T` that is protected by a mutex. In order to access the protected value, you must first create a lock. `.lockExclusive()` returns `kj::Locked` which can be used to access the underlying value. `.lockShared()` returns `kj::Locked`, [using constness to enforce thread-safe read-only access](../style-guide.md#constness) so that multiple threads can take the lock concurrently. In this way, KJ mutexes make it difficult to forget to take a lock before accessing the protected object. `kj::Locked` has a method `.wait(cond)` which temporarily releases the lock and waits, taking the lock back as soon as `cond(value)` evaluates true. This provides a much cleaner and more readable interface than traditional conditional variables. @@ -908,9 +910,11 @@ The opposite of forking promises is joining promises. There are two types of joi For an exclusive join, use `promise.exclusiveJoin(kj::mv(otherPromise))`. The two promises must return the same type. The result is a promise that returns whichever result is produced first, and cancels the other promise at that time. (To exclusively join more than two promises, call `.exclusiveJoin()` multiple times in a chain.) -To perform an inclusive join, use `kj::joinPromises()`. This turns a `kj::Array>` into a `kj::Promise>`. However, note that `kj::joinPromises()` has a couple common gotchas: +To perform an inclusive join, use `kj::joinPromises()` or `kj::joinPromisesFailFast()`. These turn a `kj::Array>` into a `kj::Promise>`. However, note that `kj::joinPromises()` has a couple common gotchas: * Trailing continuations on the promises passed to `kj::joinPromises()` are evaluated lazily after all the promises become ready. Use `.eagerlyEvaluate()` on each one to force trailing continuations to happen eagerly. (See earlier discussion under "Background Tasks".) -* If any promise in the array rejects, the exception will be held until all other promises have completed (or rejected), and only then will the exception propagate. In practice we've found that most uses of `kj::joinPromises()` would prefer "exclusive" or "fail-fast" behavior in the case of an exception, but as of this writing we have not yet introduced a function that does this. +* If any promise in the array rejects, the exception will be held until all other promises have completed (or rejected), and only then will the exception propagate. In practice we've found that most uses of `kj::joinPromises()` would prefer "exclusive" or "fail-fast" behavior in the case of an exception. + +`kj::joinPromisesFailFast()` addresses the gotchas described above: promise continuations are evaluated eagerly, and if any promise results in an exception, the join promise is immediately rejected with that exception. ### Threads @@ -938,7 +942,40 @@ kj::Promise promise = **CAUTION:** Fibers produce attractive-looking code, but have serious drawbacks. Every fiber must allocate a new call stack, which is typically rather large. The above example allocates a 64kb stack, which is the _minimum_ supported size. Some programs and libraries expect to be able to allocate megabytes of data on the stack. On modern Linux systems, a default stack size of 8MB is typical. Stack space is allocated lazily on page faults, but just setting up the memory mapping is much more expensive than a typical `malloc()`. If you create lots of fibers, you should use `kj::FiberPool` to reduce allocation costs -- but while this reduces allocation overhead, it will increase memory usage. -Because of this, fibers should not be used just to make code look nice (C++20's `co_await`, which KJ will soon support, is a better way to do that). Instead, the main use case for fibers is to be able to call into existing libraries that are not designed to operate in an asynchronous way. For example, say you find a library that performs stream I/O, and lets you provide your own `read()`/`write()` implementations, but expects those implementations to operate in a blocking fashion. With fibers, you can use such a library within the asynchronous KJ event loop. +Because of this, fibers should not be used just to make code look nice (C++20's `co_await`, described below, is a better way to do that). Instead, the main use case for fibers is to be able to call into existing libraries that are not designed to operate in an asynchronous way. For example, say you find a library that performs stream I/O, and lets you provide your own `read()`/`write()` implementations, but expects those implementations to operate in a blocking fashion. With fibers, you can use such a library within the asynchronous KJ event loop. + +### Coroutines + +C++20 brings us coroutines, which, like fibers, allow code to be written in a synchronous / blocking style while running inside the KJ event loop. Coroutines accomplish this with a different strategy than fibers: instead of running code on an alternate stack and switching stacks on suspension, coroutines save local variables and temporary objects in a heap-allocated "coroutine frame" and always unwind the stack on suspension. + +A C++ function is a KJ coroutine if it follows these two rules: +- The function returns a `kj::Promise`. +- The function uses a `co_await` or `co_return` keyword in its implementation. + +```c++ +kj::Promise aCoroutine() { + int i = co_await someAsyncFunc(); + i += co_await anotherAsyncFunc(); + co_return i; +}); + +// Call like any regular promise-returning function. +auto promise = aCoroutine(); +``` + +The promise returned by a coroutine owns the coroutine frame. If you destroy the promise, any objects alive in the frame will be destroyed, and the frame freed, thus cancellation works exactly as you'd expect. + +There are some caveats one should be aware of while writing coroutines: +- Lambda captures **do not** live inside of the coroutine frame, meaning lambda objects must outlive any coroutine Promises they return, or else the coroutine will encounter dangling references to captured objects. This is a defect in the C++ standard: https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rcoro-capture. To safely use a capturing lambda as a coroutine, first wrap it using `kj::coCapture([captures]() { ... })`, then invoke that object. +- Holding a mutex lock across a `co_await` is almost always a bad idea, with essentially the same problems as holding a lock while calling `promise.wait(waitScope)`. This would cause the coroutine to hold the lock for however many turns of the event loop is required to drive the coroutine to release the lock; if I/O is involved, this could cause significant problems. Additionally, a reentrant call to the coroutine on the same thread would deadlock. Instead, if a coroutine must temporarily hold a lock, always keep the lock in a new lexical scope without any `co_await`. +- Attempting to define (and use) a variable-length array will cause a compile error, because the size of coroutine frames must be knowable at compile-time. The error message that clang emits for this, "Coroutines cannot handle non static allocas yet", suggests this may be relaxed in the future. + +As of this writing, KJ supports C++20 coroutines and Coroutines TS coroutines, the latter being an experimental precursor to C++20 coroutines. They are functionally the same thing, but enabled with different compiler/linker flags: + +- Enable C++20 coroutines by requesting that language standard from your compiler. +- Enable Coroutines TS coroutines with `-fcoroutines-ts` in C++17 clang, and `/await` in MSVC. + +KJ prefers C++20 coroutines when both implementations are available. ### Unit testing tips @@ -964,6 +1001,8 @@ KJ_ASSERT(promise.poll(waitScope)); promise.wait(waitScope); ``` +Sometimes, you may need to ensure that some promise has completed that you don't have a reference to, so you can observe that some side effect has occurred. You can use `waitScope.poll()` to flush the event loop without waiting for a specific promise to complete. + ## System I/O ### Async I/O @@ -1008,7 +1047,7 @@ KJ provides a time library in `kj/time.h` which uses the type system to enforce `kj::Clock` is a simple interface whose `now()` method returns the current `kj::Date`. `kj::MonotonicClock` is a similar interface returning a `kj::TimePoint`, but with the guarantee that times returned always increase (whereas a `kj::Clock` might go "back in time" if the user manually modifies their system clock). -`kj::systemCoarseCalendarClock()`, `kj::systemPreciseCalendarClock()`, `kj::systemCoarseMonotonicClock()`, `kj::systemPreciseMonotonicClock()` are global functions that return implementations of `kj::Clock` or `kJ::MonotonicClock` based on sytem time. +`kj::systemCoarseCalendarClock()`, `kj::systemPreciseCalendarClock()`, `kj::systemCoarseMonotonicClock()`, `kj::systemPreciseMonotonicClock()` are global functions that return implementations of `kj::Clock` or `kJ::MonotonicClock` based on system time. `kj::Timer` provides an async (promise-based) interface to wait for a specified time to pass. A `kj::Timer` is provided via `kj::AsyncIoProvider`, constructed using `kj::setupAsyncIo()` (see earlier discussion on async I/O). diff --git a/libs/EXTERNAL/capnproto/release.sh b/libs/EXTERNAL/capnproto/release.sh index f9f104e776d..4225682956e 100755 --- a/libs/EXTERNAL/capnproto/release.sh +++ b/libs/EXTERNAL/capnproto/release.sh @@ -2,7 +2,7 @@ set -euo pipefail -if [ "$1" != "package" ]; then +if [ "$1" != "package" ] && [ "$1" != "bump-major" ]; then if (grep -r KJ_DBG c++/src | egrep -v '/debug(-test)?[.]' | grep -v 'See KJ_DBG\.$'); then echo '*** Error: There are instances of KJ_DBG in the code.' >&2 exit 1 @@ -50,7 +50,7 @@ update_version() { c++/src/capnp/common.h local NEW_COMBINED=$(( ${NEW_ARR[0]} * 1000000 + ${NEW_ARR[1]} * 1000 + ${NEW_ARR[2]:-0 })) - doit sed -i -re "s/^#if CAPNP_VERSION != [0-9]*\$/#if CAPNP_VERSION != $NEW_COMBINED/g" \ + doit sed -i -re "s/^#elif CAPNP_VERSION != [0-9]*\$/#elif CAPNP_VERSION != $NEW_COMBINED/g" \ c++/src/*/*.capnp.h c++/src/*/*/*.capnp.h doit git commit -a -m "Set $BRANCH_DESC version to $NEW." @@ -146,6 +146,14 @@ done_banner() { BRANCH=$(git rev-parse --abbrev-ref HEAD) case "${1-}:$BRANCH" in + bump-major:* ) + echo "Bump major version number on HEAD." + HEAD_VERSION=$(get_version '^[0-9]+[.][0-9]+-dev$') + OLD_MAJOR=$(echo $HEAD_VERSION | cut -d. -f1) + NEW_VERSION=$(( OLD_MAJOR + 1 )).0-dev + update_version $HEAD_VERSION $NEW_VERSION "mainline" + ;; + # ====================================================================================== candidate:master ) echo "New major release." diff --git a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-0-c++-integer-overflow.md b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-0-c++-integer-overflow.md index d25b2a19fcb..300647e2fe1 100644 --- a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-0-c++-integer-overflow.md +++ b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-0-c++-integer-overflow.md @@ -35,7 +35,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/f343f0dbd0a2e87f17cd74f14186ed73e3fbdbfa +[0]: https://github.com/capnproto/capnproto/commit/f343f0dbd0a2e87f17cd74f14186ed73e3fbdbfa Details ======= @@ -97,6 +97,6 @@ following preventative measures going forward: I am pleased that measures 1, 2, and 3 all detected this bug, suggesting that they have a high probability of catching any similar bugs. -[1]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-all-cpu-amplification.md -[2]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md +[1]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-all-cpu-amplification.md +[2]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md [3]: https://capnproto.org/news/2015-03-02-security-advisory-and-integer-overflow-protection.html diff --git a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-1-c++-integer-underflow.md b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-1-c++-integer-underflow.md index 970f8b9aec6..06a3cd2f402 100644 --- a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-1-c++-integer-underflow.md +++ b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-1-c++-integer-underflow.md @@ -37,7 +37,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/26bcceda72372211063d62aab7e45665faa83633 +[0]: https://github.com/capnproto/capnproto/commit/26bcceda72372211063d62aab7e45665faa83633 Details ======= @@ -106,5 +106,5 @@ cleanup, but [check the Cap'n Proto blog for an in-depth discussion][2]. This problem is also caught by capnp/fuzz-test.c++, which *has* been merged into master but likely doesn't have as broad coverage. -[1]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md +[1]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md [2]: https://capnproto.org/news/2015-03-02-security-advisory-and-integer-overflow-protection.html diff --git a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-2-all-cpu-amplification.md b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-2-all-cpu-amplification.md index 94ad3361282..1bc4bccd2a3 100644 --- a/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-2-all-cpu-amplification.md +++ b/libs/EXTERNAL/capnproto/security-advisories/2015-03-02-2-all-cpu-amplification.md @@ -35,7 +35,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d +[0]: https://github.com/capnproto/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d Details ======= diff --git a/libs/EXTERNAL/capnproto/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md b/libs/EXTERNAL/capnproto/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md index aee7f1782c0..bd25698d5ed 100644 --- a/libs/EXTERNAL/capnproto/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md +++ b/libs/EXTERNAL/capnproto/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md @@ -37,7 +37,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.2.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/80149744bdafa3ad4eedc83f8ab675e27baee868 +[0]: https://github.com/capnproto/capnproto/commit/80149744bdafa3ad4eedc83f8ab675e27baee868 Details ======= @@ -55,7 +55,7 @@ loop that doesn't call any application code. Only CPU time is possibly consumed, not RAM or other resources. However, it is still possible to create significant delays for the receiver with a specially-crafted message. -[1]: https://github.com/sandstorm-io/capnproto/blob/master/security-advisories/2015-03-02-2-all-cpu-amplification.md +[1]: https://github.com/capnproto/capnproto/blob/master/security-advisories/2015-03-02-2-all-cpu-amplification.md Preventative measures ===================== diff --git a/libs/EXTERNAL/capnproto/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md b/libs/EXTERNAL/capnproto/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md index 683b8e3d0de..49221fc26b5 100644 --- a/libs/EXTERNAL/capnproto/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md +++ b/libs/EXTERNAL/capnproto/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md @@ -42,7 +42,7 @@ Fixed in - Windows: https://capnproto.org/capnproto-c++-win32-0.5.3.1.zip - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/52bc956459a5e83d7c31be95763ff6399e064ae4 +[0]: https://github.com/capnproto/capnproto/commit/52bc956459a5e83d7c31be95763ff6399e064ae4 Details ======= @@ -144,4 +144,4 @@ technically-correct solution has been implemented in the next commit, extensive refactoring, it is not appropriate for cherry-picking, and will only land in versions 0.6 and up. -[2]: https://github.com/sandstorm-io/capnproto/commit/2ca8e41140ebc618b8fb314b393b0a507568cf21 +[2]: https://github.com/capnproto/capnproto/commit/2ca8e41140ebc618b8fb314b393b0a507568cf21 diff --git a/libs/EXTERNAL/capnproto/security-advisories/2022-11-30-0-pointer-list-bounds.md b/libs/EXTERNAL/capnproto/security-advisories/2022-11-30-0-pointer-list-bounds.md new file mode 100644 index 00000000000..50605ce1950 --- /dev/null +++ b/libs/EXTERNAL/capnproto/security-advisories/2022-11-30-0-pointer-list-bounds.md @@ -0,0 +1,127 @@ +Problem +======= + +Out-of-bounds read due to logic error handling list-of-list. + +Discovered by +============= + +David Renshaw <dwrenshaw@gmail.com>, the maintainer of Cap'n Proto's Rust +implementation, which is affected by the same bug. David discovered this bug +while running his own fuzzer. + +Announced +========= + +2022-11-30 + +CVE +=== + +CVE-2022-46149 + +Impact +====== + +- Remotely segfault a peer by sending it a malicious message, if the victim + performs certain actions on a list-of-pointer type. +- Possible exfiltration of memory, if the victim performs additional certain + actions on a list-of-pointer type. +- To be vulnerable, an application must perform a specific sequence of actions, + described below. At present, **we are not aware of any vulnerable + application**, but we advise updating regardless. + +Fixed in +======== + +Unfortunately, the bug is present in inlined code, therefore the fix will +require rebuilding dependent applications. + +C++ fix: + +- git commit [25d34c67863fd960af34fc4f82a7ca3362ee74b9][0] +- release 0.11 (future) +- release 0.10.3: + - Unix: https://capnproto.org/capnproto-c++-0.10.3.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.10.3.zip +- release 0.9.2: + - Unix: https://capnproto.org/capnproto-c++-0.9.2.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.9.2.zip +- release 0.8.1: + - Unix: https://capnproto.org/capnproto-c++-0.8.1.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.8.1.zip +- release 0.7.1: + - Unix: https://capnproto.org/capnproto-c++-0.7.1.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.7.1.zip + +Rust fix: + +- `capnp` crate version `0.15.2`, `0.14.11`, or `0.13.7`. + +[0]: https://github.com/capnproto/capnproto/commit/25d34c67863fd960af34fc4f82a7ca3362ee74b9 + +Details +======= + +A specially-crafted pointer could escape bounds checking by exploiting +inconsistent handling of pointers when a list-of-structs is downgraded to a +list-of-pointers. + +For an in-depth explanation of how this bug works, see [David Renshaw's +blog post][1]. This details below focus only on determining whether an +application is vulnerable. + +In order to be vulnerable, an application must have certain properties. + +First, the application must accept messages with a schema in which a field has +list-of-pointer type. This includes `List(Text)`, `List(Data)`, +`List(List(T))`, or `List(C)` where `C` is an interface type. In the following +discussion, we will assume this field is named `foo`. + +Second, the application must accept a message of this schema from a malicious +source, where the attacker can maliciously encode the pointer representing the +field `foo`. + +Third, the application must call `getFoo()` to obtain a `List::Reader` for +the field, and then use it in one of the following two ways: + +1. Pass it as the parameter to another message's `setFoo()`, thus copying the + field into a new message. Note that copying the parent struct as a whole + will *not* trigger the bug; the bug only occurs if the specific field `foo` + is get/set on its own. + +2. Convert it into `AnyList::Reader`, and then attempt to access it through + that. This is much less likely; very few apps use the `AnyList` API. + +The dynamic API equivalents of these actions (`capnp/dynamic.h`) are also +affected. + +If the application does these steps, the attacker may be able to cause the +Cap'n Proto implementation to read beyond the end of the message. This could +induce a segmentation fault. Or, worse, data that happened to be in memory +immediately after the message might be returned as if it were part of the +message. In the latter case, if the application then forwards that data back +to the attacker or sends it to another third party, this could result in +exfiltration of secrets. + +Any exfiltration of data would have the following limitations: + +* The attacker could exfiltrate no more than 512 KiB of memory immediately + following the message buffer. + * The attacker chooses in advance how far past the end of the message to + read. + * The attacker's message itself must be larger than the exfiltrated data. + Note that a sufficiently large message buffer will likely be allocated + using mmap() in which case the attack will likely segfault. +* The attack can only work if the 8 bytes immediately following the + exfiltrated data contains a valid in-bounds Cap'n Proto pointer. The + easiest way to achieve this is if the pointer is null, i.e. 8 bytes of zero. + * The attacker must specify exactly how much data to exfiltrate, so must + guess exactly where such a valid pointer will exist. + * If the exfiltrated data is not followed by a valid pointer, the attack + will throw an exception. If an application has chosen to ignore exceptions + (e.g. by compiling with `-fno-exceptions` and not registering an + alternative exception callback) then the attack may be able to proceed + anyway. + +[1]: https://dwrensha.github.io/capnproto-rust/2022/11/30/out_of_bounds_memory_access_bug.html diff --git a/libs/EXTERNAL/capnproto/style-guide.md b/libs/EXTERNAL/capnproto/style-guide.md index d3d47ef0f56..3dc663333da 100644 --- a/libs/EXTERNAL/capnproto/style-guide.md +++ b/libs/EXTERNAL/capnproto/style-guide.md @@ -429,7 +429,7 @@ We use: * Clang for compiling. * `KJ_DBG()` for simple debugging. * Valgrind for complicated debugging. -* [Ekam](https://github.com/sandstorm-io/ekam) for a build system. +* [Ekam](https://github.com/capnproto/ekam) for a build system. * Git for version control. ## Irrelevant formatting rules @@ -458,7 +458,9 @@ There has also never been any agreement on C++ file extensions, for some reason. * Indents are two spaces. * Never use tabs. * Maximum line length is 100 characters. -* Indent a continuation line by four spaces, *or* line them up nicely with the previous line if it makes it easier to read. +* Indent continuation lines for braced init lists by two spaces. +* Indent all other continuation lines by four spaces. +* Alternatively, line up continuation lines with previous lines if it makes them easier to read. * Place a space between a keyword and an open parenthesis, e.g.: `if (foo)` * Do not place a space between a function name and an open parenthesis, e.g.: `foo(bar)` * Place an opening brace at the end of the statement which initiates the block, not on its own line. @@ -469,6 +471,7 @@ There has also never been any agreement on C++ file extensions, for some reason. * Statements inside a `namespace` are **not** indented unless the namespace is a short block that is just forward-declaring things at the top of a file. * Set your editor to strip trailing whitespace on save, otherwise other people who use this setting will see spurious diffs when they edit a file after you. +
if (foo) { bar(); diff --git a/libs/EXTERNAL/capnproto/super-test.sh b/libs/EXTERNAL/capnproto/super-test.sh index e578105c97c..8d63659d09c 100755 --- a/libs/EXTERNAL/capnproto/super-test.sh +++ b/libs/EXTERNAL/capnproto/super-test.sh @@ -122,6 +122,8 @@ while [ $# -gt 0 ]; do shift ;; clang* ) + # Need to set CC as well for configure to handle -fcoroutines-ts. + export CC=clang${1#clang} export CXX=clang++${1#clang} if [ "$1" != "clang-5.0" ]; then export LIB_FUZZING_ENGINE=-fsanitize=fuzzer @@ -400,11 +402,36 @@ fi if [ $IS_CLANG = yes ]; then # Don't fail out on this ridiculous "argument unused during compilation" warning. export CXXFLAGS="$CXXFLAGS -Wno-error=unused-command-line-argument" + + # Enable coroutines if supported. + if [ "${CXX#*-}" -ge 14 ] 2>/dev/null; then + # Somewhere between version 10 and 14, Clang started supporting coroutines as a C++20 feature, + # and started issuing deprecating warnings for -fcoroutines-ts. (I'm not sure which version it + # was exactly since our CI jumped from 10 to 14, so I'm somewhat arbitrarily choosing 14 as the + # cutoff.) + export CXXFLAGS="$CXXFLAGS -std=c++20 -stdlib=libc++" + export LDFLAGS="-stdlib=libc++" + + # TODO(someday): On Ubuntu 22.04, clang-14 with -stdlib=libc++ fails to link with libfuzzer, + # which looks like it might itself be linked against libstdc++? Need to investigate. + unset LIB_FUZZING_ENGINE + elif [ "${CXX#*-}" -ge 10 ] 2>/dev/null; then + # At the moment, only our clang-10 CI run seems to like -fcoroutines-ts. Earlier versions seem + # to have a library misconfiguration causing ./configure to result in the following error: + # conftest.cpp:12:12: fatal error: 'initializer_list' file not found + # #include + # Let's use any clang version >= 10 so that if we move to a newer version, we'll get additional + # coverage by default. + export CXXFLAGS="$CXXFLAGS -std=gnu++17 -stdlib=libc++ -fcoroutines-ts" + export LDFLAGS="-fcoroutines-ts -stdlib=libc++" + fi else # GCC emits uninitialized warnings all over and they seem bogus. We use valgrind to test for # uninitialized memory usage later on. GCC 4 also emits strange bogus warnings with # -Wstrict-overflow, so we disable it. CXXFLAGS="$CXXFLAGS -Wno-maybe-uninitialized -Wno-strict-overflow" + + # TODO(someday): Enable coroutines in g++ if supported. fi cd c++ diff --git a/vpr/src/pack/cluster_util.cpp b/vpr/src/pack/cluster_util.cpp index fe398bebb09..f8d0b5d9f9e 100644 --- a/vpr/src/pack/cluster_util.cpp +++ b/vpr/src/pack/cluster_util.cpp @@ -1245,7 +1245,7 @@ enum e_block_pack_status try_place_atom_block_rec(const t_pb_graph_node* pb_grap } else { /* if this is not the first child of this parent, must match existing parent mode */ if (parent_pb->mode != pb_graph_node->pb_type->parent_mode->index) { - return BLK_FAILED_FEASIBLE; + return e_block_pack_status::BLK_FAILED_FEASIBLE; } }