Skip to content

Commit a8c8a9b

Browse files
committed
Fix a major stdin wakeup race condition
1 parent 0b4f966 commit a8c8a9b

16 files changed

+172
-325
lines changed

Diff for: src/host/ApiRoutines.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ class ApiRoutines : public IApiRoutines
5656
const bool IsUnicode,
5757
const bool IsPeek,
5858
const bool IsWaitAllowed,
59-
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
59+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
6060

6161
[[nodiscard]] HRESULT ReadConsoleImpl(IConsoleInputObject& context,
6262
std::span<char> buffer,
6363
size_t& written,
64-
std::unique_ptr<IWaitRoutine>& waiter,
64+
CONSOLE_API_MSG* pWaitReplyMessage,
6565
const std::wstring_view initialData,
6666
const std::wstring_view exeName,
6767
INPUT_READ_HANDLE_DATA& readHandleState,
@@ -73,12 +73,12 @@ class ApiRoutines : public IApiRoutines
7373
[[nodiscard]] HRESULT WriteConsoleAImpl(IConsoleOutputObject& context,
7474
const std::string_view buffer,
7575
size_t& read,
76-
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
76+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
7777

7878
[[nodiscard]] HRESULT WriteConsoleWImpl(IConsoleOutputObject& context,
7979
const std::wstring_view buffer,
8080
size_t& read,
81-
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
81+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
8282

8383
#pragma region ThreadCreationInfo
8484
[[nodiscard]] HRESULT GetConsoleLangIdImpl(LANGID& langId) noexcept override;

Diff for: src/host/_stream.cpp

+57-125
Original file line numberDiff line numberDiff line change
@@ -415,30 +415,9 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo)
415415
// - pwchBuffer - wide character text to be inserted into buffer
416416
// - pcbBuffer - byte count of pwchBuffer on the way in, number of bytes consumed on the way out.
417417
// - screenInfo - Screen Information class to write the text into at the current cursor position
418-
// - ppWaiter - If writing to the console is blocked for whatever reason, this will be filled with a pointer to context
419-
// that can be used by the server to resume the call at a later time.
420-
// Return Value:
421-
// - STATUS_SUCCESS if OK.
422-
// - CONSOLE_STATUS_WAIT if we couldn't finish now and need to be called back later (see ppWaiter).
423-
// - Or a suitable NTSTATUS format error code for memory/string/math failures.
424-
[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(*pcbBuffer) PCWCHAR pwchBuffer,
425-
_Inout_ size_t* const pcbBuffer,
426-
SCREEN_INFORMATION& screenInfo,
427-
std::unique_ptr<WriteData>& waiter)
418+
[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str)
428419
try
429420
{
430-
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
431-
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
432-
{
433-
waiter = std::make_unique<WriteData>(screenInfo,
434-
pwchBuffer,
435-
*pcbBuffer,
436-
gci.OutputCP);
437-
return CONSOLE_STATUS_WAIT;
438-
}
439-
440-
const std::wstring_view str{ pwchBuffer, *pcbBuffer / sizeof(WCHAR) };
441-
442421
if (WI_IsAnyFlagClear(screenInfo.OutputMode, ENABLE_VIRTUAL_TERMINAL_PROCESSING | ENABLE_PROCESSED_OUTPUT))
443422
{
444423
WriteCharsLegacy(screenInfo, str, nullptr);
@@ -447,55 +426,9 @@ try
447426
{
448427
WriteCharsVT(screenInfo, str);
449428
}
450-
451-
return STATUS_SUCCESS;
452-
}
453-
NT_CATCH_RETURN()
454-
455-
// Routine Description:
456-
// - This method performs the actual work of attempting to write to the console, converting data types as necessary
457-
// to adapt from the server types to the legacy internal host types.
458-
// - It operates on Unicode data only. It's assumed the text is translated by this point.
459-
// Arguments:
460-
// - OutContext - the console output object to write the new text into
461-
// - pwsTextBuffer - wide character text buffer provided by client application to insert
462-
// - cchTextBufferLength - text buffer counted in characters
463-
// - pcchTextBufferRead - character count of the number of characters we were able to insert before returning
464-
// - ppWaiter - If we are blocked from writing now and need to wait, this is filled with contextual data for the server to restore the call later
465-
// Return Value:
466-
// - S_OK if successful.
467-
// - S_OK if we need to wait (check if ppWaiter is not nullptr).
468-
// - Or a suitable HRESULT code for math/string/memory failures.
469-
[[nodiscard]] HRESULT WriteConsoleWImplHelper(IConsoleOutputObject& context,
470-
const std::wstring_view buffer,
471-
size_t& read,
472-
std::unique_ptr<WriteData>& waiter) noexcept
473-
{
474-
try
475-
{
476-
// Set out variables in case we exit early.
477-
read = 0;
478-
waiter.reset();
479-
480-
// Convert characters to bytes to give to DoWriteConsole.
481-
size_t cbTextBufferLength;
482-
RETURN_IF_FAILED(SizeTMult(buffer.size(), sizeof(wchar_t), &cbTextBufferLength));
483-
484-
auto Status = DoWriteConsole(const_cast<wchar_t*>(buffer.data()), &cbTextBufferLength, context, waiter);
485-
486-
// Convert back from bytes to characters for the resulting string length written.
487-
read = cbTextBufferLength / sizeof(wchar_t);
488-
489-
if (Status == CONSOLE_STATUS_WAIT)
490-
{
491-
FAIL_FAST_IF_NULL(waiter.get());
492-
Status = STATUS_SUCCESS;
493-
}
494-
495-
RETURN_NTSTATUS(Status);
496-
}
497-
CATCH_RETURN();
429+
return S_OK;
498430
}
431+
CATCH_RETURN()
499432

500433
// Routine Description:
501434
// - Writes non-Unicode formatted data into the given console output object.
@@ -514,13 +447,12 @@ NT_CATCH_RETURN()
514447
[[nodiscard]] HRESULT ApiRoutines::WriteConsoleAImpl(IConsoleOutputObject& context,
515448
const std::string_view buffer,
516449
size_t& read,
517-
std::unique_ptr<IWaitRoutine>& waiter) noexcept
450+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
518451
{
519452
try
520453
{
521454
// Ensure output variables are initialized.
522455
read = 0;
523-
waiter.reset();
524456

525457
if (buffer.empty())
526458
{
@@ -620,67 +552,63 @@ NT_CATCH_RETURN()
620552
wstr.resize((dbcsLength + mbPtrLength) / sizeof(wchar_t));
621553
}
622554

623-
// Hold the specific version of the waiter locally so we can tinker with it if we have to store additional context.
624-
std::unique_ptr<WriteData> writeDataWaiter{};
625-
626-
// Make the W version of the call
627-
size_t wcBufferWritten{};
628-
const auto hr{ WriteConsoleWImplHelper(screenInfo, wstr, wcBufferWritten, writeDataWaiter) };
629-
630-
// If there is no waiter, process the byte count now.
631-
if (nullptr == writeDataWaiter.get())
555+
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
556+
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
632557
{
633-
// Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead.
634-
// For UTF-8 conversions, we've already returned this information above.
635-
if (CP_UTF8 != codepage)
636-
{
637-
size_t mbBufferRead{};
638-
639-
// Start by counting the number of A bytes we used in printing our W string to the screen.
640-
try
641-
{
642-
mbBufferRead = GetALengthFromW(codepage, { wstr.data(), wcBufferWritten });
643-
}
644-
CATCH_LOG();
645-
646-
// If we captured a byte off the string this time around up above, it means we didn't feed
647-
// it into the WriteConsoleW above, and therefore its consumption isn't accounted for
648-
// in the count we just made. Add +1 to compensate.
649-
if (leadByteCaptured)
650-
{
651-
mbBufferRead++;
652-
}
558+
const auto waiter = new WriteData(screenInfo, std::move(wstr), gci.OutputCP);
653559

654-
// If we consumed an internally-stored lead byte this time around up above, it means that we
655-
// fed a byte into WriteConsoleW that wasn't a part of this particular call's request.
656-
// We need to -1 to compensate and tell the caller the right number of bytes consumed this request.
657-
if (leadByteConsumed)
658-
{
659-
mbBufferRead--;
660-
}
661-
662-
read = mbBufferRead;
663-
}
664-
}
665-
else
666-
{
667560
// If there is a waiter, then we need to stow some additional information in the wait structure so
668561
// we can synthesize the correct byte count later when the wait routine is triggered.
669562
if (CP_UTF8 != codepage)
670563
{
671564
// For non-UTF8 codepages, save the lead byte captured/consumed data so we can +1 or -1 the final decoded count
672565
// in the WaitData::Notify method later.
673-
writeDataWaiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed);
566+
waiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed);
674567
}
675568
else
676569
{
677570
// For UTF8 codepages, just remember the consumption count from the UTF-8 parser.
678-
writeDataWaiter->SetUtf8ConsumedCharacters(read);
571+
waiter->SetUtf8ConsumedCharacters(read);
679572
}
573+
574+
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, waiter);
575+
return CONSOLE_STATUS_WAIT;
680576
}
681577

682-
// Give back the waiter now that we're done with tinkering with it.
683-
waiter.reset(writeDataWaiter.release());
578+
// Make the W version of the call
579+
const auto hr = DoWriteConsole(screenInfo, wstr);
580+
581+
// Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead.
582+
// For UTF-8 conversions, we've already returned this information above.
583+
if (CP_UTF8 != codepage)
584+
{
585+
size_t mbBufferRead{};
586+
587+
// Start by counting the number of A bytes we used in printing our W string to the screen.
588+
try
589+
{
590+
mbBufferRead = GetALengthFromW(codepage, wstr);
591+
}
592+
CATCH_LOG();
593+
594+
// If we captured a byte off the string this time around up above, it means we didn't feed
595+
// it into the WriteConsoleW above, and therefore its consumption isn't accounted for
596+
// in the count we just made. Add +1 to compensate.
597+
if (leadByteCaptured)
598+
{
599+
mbBufferRead++;
600+
}
601+
602+
// If we consumed an internally-stored lead byte this time around up above, it means that we
603+
// fed a byte into WriteConsoleW that wasn't a part of this particular call's request.
604+
// We need to -1 to compensate and tell the caller the right number of bytes consumed this request.
605+
if (leadByteConsumed)
606+
{
607+
mbBufferRead--;
608+
}
609+
610+
read = mbBufferRead;
611+
}
684612

685613
return hr;
686614
}
@@ -703,20 +631,24 @@ NT_CATCH_RETURN()
703631
[[nodiscard]] HRESULT ApiRoutines::WriteConsoleWImpl(IConsoleOutputObject& context,
704632
const std::wstring_view buffer,
705633
size_t& read,
706-
std::unique_ptr<IWaitRoutine>& waiter) noexcept
634+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
707635
{
708636
try
709637
{
710638
LockConsole();
711639
auto unlock = wil::scope_exit([&] { UnlockConsole(); });
712640

713-
std::unique_ptr<WriteData> writeDataWaiter;
714-
RETURN_IF_FAILED(WriteConsoleWImplHelper(context.GetActiveBuffer(), buffer, read, writeDataWaiter));
715-
716-
// Transfer specific waiter pointer into the generic interface wrapper.
717-
waiter.reset(writeDataWaiter.release());
641+
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
642+
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
643+
{
644+
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new WriteData(context, std::wstring{ buffer }, gci.OutputCP));
645+
return CONSOLE_STATUS_WAIT;
646+
}
718647

719-
return S_OK;
648+
read = 0;
649+
auto Status = DoWriteConsole(context, buffer);
650+
read = buffer.size();
651+
return Status;
720652
}
721653
CATCH_RETURN();
722654
}

Diff for: src/host/_stream.h

+1-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,4 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo);
2525

2626
// NOTE: console lock must be held when calling this routine
2727
// String has been translated to unicode at this point.
28-
[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(pcbBuffer) const wchar_t* pwchBuffer,
29-
_Inout_ size_t* const pcbBuffer,
30-
SCREEN_INFORMATION& screenInfo,
31-
std::unique_ptr<WriteData>& waiter);
28+
[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str);

Diff for: src/host/directio.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,10 @@ using Microsoft::Console::Interactivity::ServiceLocator;
5858
const bool IsUnicode,
5959
const bool IsPeek,
6060
const bool IsWaitAllowed,
61-
std::unique_ptr<IWaitRoutine>& waiter) noexcept
61+
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
6262
{
6363
try
6464
{
65-
waiter.reset();
66-
6765
if (eventReadCount == 0)
6866
{
6967
return STATUS_SUCCESS;
@@ -83,9 +81,7 @@ using Microsoft::Console::Interactivity::ServiceLocator;
8381
{
8482
// If we're told to wait until later, move all of our context
8583
// to the read data object and send it back up to the server.
86-
waiter = std::make_unique<DirectReadData>(&inputBuffer,
87-
&readHandleState,
88-
eventReadCount);
84+
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new DirectReadData(&inputBuffer, &readHandleState, eventReadCount));
8985
}
9086
return Status;
9187
}

Diff for: src/host/server.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ class CONSOLE_INFORMATION :
166166
MidiAudio _midiAudio;
167167
};
168168

169-
#define CONSOLE_STATUS_WAIT 0xC0030001
170-
#define CONSOLE_STATUS_READ_COMPLETE 0xC0030002
171-
#define CONSOLE_STATUS_WAIT_NO_BLOCK 0xC0030003
169+
#define CONSOLE_STATUS_WAIT ((HRESULT)0xC0030001)
170+
#define CONSOLE_STATUS_READ_COMPLETE ((HRESULT)0xC0030002)
171+
#define CONSOLE_STATUS_WAIT_NO_BLOCK ((HRESULT)0xC0030003)
172172

173173
#include "../server/ObjectHandle.h"
174174

0 commit comments

Comments
 (0)