diff --git a/build/replace_just_ksa_x64.bat b/build/replace_just_ksa_x64.bat
new file mode 100644
index 00000000..e1819858
--- /dev/null
+++ b/build/replace_just_ksa_x64.bat
@@ -0,0 +1,21 @@
+@echo off
+echo This must be run as administrator.
+
+set servicepath="%ProgramFiles%\Windows MIDI Services\Service"
+set apipath="%ProgramFiles%\Windows MIDI Services\API"
+set dmppath="%ProgramFiles%\Windows MIDI Services\"
+set buildoutput="%midi_repo_root%src\api\VSFiles\x64\Release"
+
+
+
+echo Stopping midisrv
+net stop midisrv
+
+echo Copying KSA Transport
+mkdir %servicepath%
+copy /Y %buildoutput%\Midi2.KSAggregateTransport.dll %servicepath%
+regsvr32 %servicepath%\Midi2.KSAggregateTransport.dll
+
+net start midisrv
+
+pause
\ No newline at end of file
diff --git a/build/staging/version/BundleInfo.wxi b/build/staging/version/BundleInfo.wxi
index 1a39f949..0f37e7c4 100644
--- a/build/staging/version/BundleInfo.wxi
+++ b/build/staging/version/BundleInfo.wxi
@@ -1,5 +1,5 @@
-
-
+
+
diff --git a/build/staging/version/WindowsMidiServicesVersion.cs b/build/staging/version/WindowsMidiServicesVersion.cs
index 205a9fa6..f385244e 100644
--- a/build/staging/version/WindowsMidiServicesVersion.cs
+++ b/build/staging/version/WindowsMidiServicesVersion.cs
@@ -9,15 +9,15 @@ public static class MidiNuGetBuildInformation
{
public const bool IsPreview = true;
public const string Source = "GitHub Preview";
- public const string BuildDate = "2026-01-19";
+ public const string BuildDate = "2026-02-16";
public const string Name = "Service Preview 14";
- public const string BuildFullVersion = "1.0.15-preview.14.73";
+ public const string BuildFullVersion = "1.0.15-preview.14.79";
public const ushort VersionMajor = 1;
public const ushort VersionMinor = 0;
public const ushort VersionPatch = 15;
- public const ushort VersionBuildNumber = 73;
- public const string Preview = "preview.14.73";
- public const string AssemblyFullVersion = "1.0.15.73";
- public const string FileFullVersion = "1.0.15.73";
+ public const ushort VersionBuildNumber = 79;
+ public const string Preview = "preview.14.79";
+ public const string AssemblyFullVersion = "1.0.15.79";
+ public const string FileFullVersion = "1.0.15.79";
}
}
diff --git a/build/staging/version/WindowsMidiServicesVersion.h b/build/staging/version/WindowsMidiServicesVersion.h
index 99b23347..77ede814 100644
--- a/build/staging/version/WindowsMidiServicesVersion.h
+++ b/build/staging/version/WindowsMidiServicesVersion.h
@@ -7,14 +7,14 @@
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_IS_PREVIEW true
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_SOURCE L"GitHub Preview"
-#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-01-19"
+#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-02-16"
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_NAME L"Service Preview 14"
-#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.15-preview.14.73"
+#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.15-preview.14.79"
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MAJOR 1
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MINOR 0
#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_PATCH 15
-#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 73
-#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"preview.14.73"
-#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.15.73"
+#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 79
+#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"preview.14.79"
+#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.15.79"
#endif
diff --git a/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj b/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj
index f205a17a..377f6c9b 100644
--- a/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj
+++ b/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj
@@ -79,24 +79,28 @@
USBMidi2
$(SolutionDir)VSFiles\$(Platform)\$(Configuration)\
$(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\
+ $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc
DbgengKernelDebugger
USBMidi2
$(SolutionDir)VSFiles\$(Platform)\$(Configuration)\
$(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\
+ $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc
DbgengKernelDebugger
USBMidi2
$(SolutionDir)VSFiles\$(Platform)\$(Configuration)\
$(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\
+ $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc
DbgengKernelDebugger
USBMidi2
$(SolutionDir)VSFiles\$(Platform)\$(Configuration)\
$(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\
+ $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc
false
diff --git a/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h b/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h
new file mode 100644
index 00000000..4c166d7f
--- /dev/null
+++ b/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License
+// ============================================================================
+// This is part of the Windows MIDI Services App API and should be used
+// in your Windows application via an official binary distribution.
+// Further information: https://aka.ms/midi
+// ============================================================================
+
+#pragma once
+
+class Feature_Servicing_MIDI2VirtualPortDriversFix
+{
+public:
+ static bool IsEnabled()
+ {
+ return true;
+ }
+};
\ No newline at end of file
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp
index cffab844..219425de 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp
@@ -11,7 +11,6 @@
#include "ump_iterator.h"
-
_Use_decl_annotations_
HRESULT
CMidi2KSAggregateMidi::Initialize(
@@ -205,40 +204,84 @@ CMidi2KSAggregateMidi::Initialize(
wil::com_ptr_nothrow proxy;
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&proxy));
- auto initResult =
- proxy->Initialize(
- endpointDeviceInterfaceId,
- handleDupe.get(),
- pinMapEntry->PinId,
- requestedBufferSize,
- mmCssTaskId,
- m_callback,
- context,
- pinMapEntry->GroupIndex
- );
-
- if (SUCCEEDED(initResult))
+ // needed for internal consumption. Gary to replace this with feature enablement check
+ // defined in pch.h
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
{
- m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ auto initResult =
+ proxy->Initialize(
+ filterInterfaceId.c_str(),
+ handleDupe.get(),
+ pinMapEntry->PinId,
+ requestedBufferSize,
+ mmCssTaskId,
+ m_callback,
+ context,
+ pinMapEntry->GroupIndex
+ );
+
+ if (SUCCEEDED(initResult))
+ {
+ m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
+ TraceLoggingUInt32(requestedBufferSize, "buffer size"),
+ TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
+ TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
+ TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
+ );
+
+ RETURN_IF_FAILED(initResult);
+ }
}
else
{
- TraceLoggingWrite(
- MidiKSAggregateTransportTelemetryProvider::Provider(),
- MIDI_TRACE_EVENT_ERROR,
- TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
- TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
- TraceLoggingPointer(this, "this"),
- TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
- TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
- TraceLoggingUInt32(requestedBufferSize, "buffer size"),
- TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
- TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
- TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
- );
-
- RETURN_IF_FAILED(initResult);
+ auto initResult =
+ proxy->Initialize(
+ endpointDeviceInterfaceId,
+ handleDupe.get(),
+ pinMapEntry->PinId,
+ requestedBufferSize,
+ mmCssTaskId,
+ m_callback,
+ context,
+ pinMapEntry->GroupIndex
+ );
+
+ if (SUCCEEDED(initResult))
+ {
+ m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
+ TraceLoggingUInt32(requestedBufferSize, "buffer size"),
+ TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
+ TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
+ TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
+ );
+
+ RETURN_IF_FAILED(initResult);
+ }
}
+
+
}
else if (pinMapEntry->PinDataFlow == MidiFlow::MidiFlowIn)
{
@@ -247,38 +290,79 @@ CMidi2KSAggregateMidi::Initialize(
wil::com_ptr_nothrow proxy;
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&proxy));
- auto initResult =
- proxy->Initialize(
- endpointDeviceInterfaceId,
- handleDupe.get(),
- pinMapEntry->PinId,
- requestedBufferSize,
- mmCssTaskId,
- context,
- pinMapEntry->GroupIndex
- );
-
- if (SUCCEEDED(initResult))
+ // needed for internal consumption. Gary to replace this with feature enablement check
+ // defined in pch.h
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
{
- m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ auto initResult =
+ proxy->Initialize(
+ filterInterfaceId.c_str(),
+ handleDupe.get(),
+ pinMapEntry->PinId,
+ requestedBufferSize,
+ mmCssTaskId,
+ context,
+ pinMapEntry->GroupIndex
+ );
+
+ if (SUCCEEDED(initResult))
+ {
+ m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
+ TraceLoggingUInt32(requestedBufferSize, "buffer size"),
+ TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
+ TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
+ TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
+ );
+
+ RETURN_IF_FAILED(initResult);
+ }
}
else
{
- TraceLoggingWrite(
- MidiKSAggregateTransportTelemetryProvider::Provider(),
- MIDI_TRACE_EVENT_ERROR,
- TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
- TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
- TraceLoggingPointer(this, "this"),
- TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
- TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
- TraceLoggingUInt32(requestedBufferSize, "buffer size"),
- TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
- TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
- TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
- );
-
- RETURN_IF_FAILED(initResult);
+ auto initResult =
+ proxy->Initialize(
+ endpointDeviceInterfaceId,
+ handleDupe.get(),
+ pinMapEntry->PinId,
+ requestedBufferSize,
+ mmCssTaskId,
+ context,
+ pinMapEntry->GroupIndex
+ );
+
+ if (SUCCEEDED(initResult))
+ {
+ m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy));
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
+ TraceLoggingUInt32(requestedBufferSize, "buffer size"),
+ TraceLoggingUInt32(pinMapEntry->PinId, "pin id"),
+ TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"),
+ TraceLoggingWideString(filterInterfaceId.c_str(), "filter")
+ );
+
+ RETURN_IF_FAILED(initResult);
+ }
}
}
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp
index 3e552660..64692c74 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp
@@ -35,6 +35,7 @@ CMidi2KSAggregateMidiBidi::Initialize(
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
+ TraceLoggingPointer(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(device, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD)
);
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp
index 12f393c9..927ad0e8 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp
@@ -258,12 +258,26 @@ CMidi2KSAggregateMidiConfigurationManager::UpdateConfiguration(
// Resolve the EndpointDeviceId in case we matched on something else
winrt::hstring matchingEndpointDeviceId{};
- auto em = TransportState::Current().GetEndpointManager();
- if (em != nullptr)
+
+
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ auto em = TransportState::Current().GetEndpointManager2();
+ if (em != nullptr)
+ {
+ matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria);
+ }
+ }
+ else
{
- matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria);
+ auto em = TransportState::Current().GetEndpointManager();
+ if (em != nullptr)
+ {
+ matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria);
+ }
}
+
// process all the custom props like Name, Description, Image, etc.
LOG_IF_FAILED(ProcessCustomProperties(
matchingEndpointDeviceId,
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp
new file mode 100644
index 00000000..937de3f8
--- /dev/null
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp
@@ -0,0 +1,2439 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License
+// ============================================================================
+// This is part of the Windows MIDI Services App API and should be used
+// in your Windows application via an official binary distribution.
+// Further information: https://aka.ms/midi
+// ============================================================================
+
+
+
+#include "pch.h"
+
+#include // for the string stream in parsing of VID/PID/Serial from parent id
+#include // for getline for string parsing of VID/PID/Serial from parent id
+
+#include "Feature_Servicing_MIDI2FilterCreations.h"
+
+using namespace wil;
+using namespace winrt::Windows::Devices::Enumeration;
+using namespace winrt::Windows::Foundation;
+using namespace winrt::Windows::Foundation::Collections;
+using namespace Microsoft::WRL;
+using namespace Microsoft::WRL::Wrappers;
+
+#define INITIAL_ENUMERATION_TIMEOUT_MS 10000
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::Initialize(
+ IMidiDeviceManager* midiDeviceManager,
+ IMidiEndpointProtocolManager* midiEndpointProtocolManager
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ RETURN_HR_IF(E_INVALIDARG, nullptr == midiDeviceManager);
+ RETURN_HR_IF(E_INVALIDARG, nullptr == midiEndpointProtocolManager);
+
+ RETURN_IF_FAILED(midiDeviceManager->QueryInterface(__uuidof(IMidiDeviceManager), (void**)&m_midiDeviceManager));
+ RETURN_IF_FAILED(midiEndpointProtocolManager->QueryInterface(__uuidof(IMidiEndpointProtocolManager), (void**)&m_midiProtocolManager));
+
+ // needed for internal consumption. Gary to replace this with feature enablement check
+ // defined in pch.h
+ DWORD individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS };
+ if (SUCCEEDED(wil::reg::get_value_dword_nothrow(HKEY_LOCAL_MACHINE, MIDI_ROOT_REG_KEY, KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE, &individualInterfaceEnumTimeoutMS)))
+ {
+ individualInterfaceEnumTimeoutMS = max(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE);
+ individualInterfaceEnumTimeoutMS = min(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE);
+
+ m_individualInterfaceEnumTimeoutMS = individualInterfaceEnumTimeoutMS;
+ }
+ else
+ {
+ m_individualInterfaceEnumTimeoutMS = DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS;
+ }
+
+
+ // the ksa2603 fix enumerates device interfaces instead of parent devices
+
+ winrt::hstring deviceInterfaceSelector(
+ L"System.Devices.InterfaceClassGuid:=\"{6994AD04-93EF-11D0-A3CC-00A0C9223196}\" AND " \
+ L"System.Devices.InterfaceEnabled: = System.StructuredQueryType.Boolean#True");
+
+ auto additionalProps = winrt::single_threaded_vector();
+ additionalProps.Append(L"System.Devices.Parent");
+
+ m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector);
+
+ auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded);
+ auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved);
+ auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceUpdated);
+
+ auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnDeviceWatcherStopped);
+ auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnEnumerationCompleted);
+
+ m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler);
+ m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler);
+ m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler);
+ m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler);
+ m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler);
+
+ // worker thread to handle endpoint creation, since we're enumerating interfaces now and need to aggregate them
+ m_endpointCreationThreadWakeup.create(wil::EventOptions::ManualReset);
+ std::jthread endpointCreationWorkerThread(std::bind_front(&CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker, this));
+ m_endpointCreationThread = std::move(endpointCreationWorkerThread);
+
+ m_watcher.Start();
+
+ // Wait for everything to be created so that they're available immediately after service start.
+ m_EnumerationCompleted.wait(INITIAL_ENUMERATION_TIMEOUT_MS);
+
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ if (m_pendingEndpointDefinitions.size() > 0)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enumeration completed with endpoint definitions left pending.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingUInt32(static_cast(m_pendingEndpointDefinitions.size()), "count pending definitions")
+ );
+ }
+ }
+
+ return S_OK;
+}
+
+
+typedef struct {
+ BYTE GroupIndex; // index (0-15) of the group this pin maps to. It's also use when deciding on WinMM names
+ UINT32 PinId; // KS Pin number
+ MidiFlow PinDataFlow; // an input pin is MidiFlowIn, and from the user's perspective, a MIDI Output
+ std::wstring FilterId; // full filter id for this pin
+} PinMapEntryStagingEntry2;
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::UpdateNameTableWithCustomProperties(
+ std::shared_ptr masterEndpointDefinition,
+ std::shared_ptr customProperties)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition);
+ RETURN_HR_IF_NULL_EXPECTED(S_OK, customProperties);
+ RETURN_HR_IF(S_OK, customProperties->Midi1Destinations.size() == 0 && customProperties->Midi1Sources.size() == 0);
+
+ for (auto const& pinEntry : masterEndpointDefinition->MidiPins)
+ {
+ if (pinEntry->PinDataFlow == MidiFlow::MidiFlowIn)
+ {
+ // message destination (output port), pin flow is In
+ if (auto customConfiguredName = customProperties->Midi1Destinations.find(pinEntry->GroupIndex);
+ customConfiguredName != customProperties->Midi1Destinations.end())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Found custom name for a Midi 1 destination.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD),
+ TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"),
+ TraceLoggingUInt8(pinEntry->GroupIndex, "group index")
+ );
+
+ masterEndpointDefinition->EndpointNameTable.UpdateDestinationEntryCustomName(pinEntry->GroupIndex, customConfiguredName->second.Name);
+ }
+ }
+ else if (pinEntry->PinDataFlow == MidiFlow::MidiFlowOut)
+ {
+ // message source (input port), pin flow is Out
+ if (auto customConfiguredName = customProperties->Midi1Sources.find(pinEntry->GroupIndex);
+ customConfiguredName != customProperties->Midi1Sources.end())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Found custom name for a Midi 1 source.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD),
+ TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"),
+ TraceLoggingUInt8(pinEntry->GroupIndex, "group index")
+ );
+
+ masterEndpointDefinition->EndpointNameTable.UpdateSourceEntryCustomName(pinEntry->GroupIndex, customConfiguredName->second.Name);
+ }
+ }
+ }
+
+ return S_OK;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyData(
+ std::shared_ptr masterEndpointDefinition,
+ std::vector& pinMapPropertyData,
+ std::vector& groupTerminalBlocks)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition);
+
+ uint8_t currentBlockNumber{ 0 };
+ std::vector pinMapEntries{ };
+
+ for (auto const& pin : masterEndpointDefinition->MidiPins)
+ {
+ RETURN_HR_IF(E_INVALIDARG, pin->FilterDeviceId.empty());
+
+ internal::GroupTerminalBlockInternal gtb;
+
+ gtb.Number = ++currentBlockNumber;
+ gtb.GroupCount = 1; // always a single group for aggregate MIDI 1.0 devices
+
+ PinMapEntryStagingEntry2 pinMapEntry{ };
+
+ pinMapEntry.PinId = pin->PinNumber;
+ pinMapEntry.FilterId = pin->FilterDeviceId;
+ pinMapEntry.PinDataFlow = pin->PinDataFlow;
+
+ //MidiFlow flowFromUserPerspective;
+
+ pinMapEntry.GroupIndex = pin->GroupIndex;
+ gtb.FirstGroupIndex = pin->GroupIndex;
+
+ if (pin->PinDataFlow == MidiFlow::MidiFlowIn) // pin in, so user out : A MIDI Destination
+ {
+ gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_INPUT; // from the pin/gtb's perspective
+
+ auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetDestinationEntry(gtb.FirstGroupIndex);
+ if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0))
+ {
+ gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName);
+ }
+ }
+ else if (pin->PinDataFlow == MidiFlow::MidiFlowOut) // pin out, so user in : A MIDI Source
+ {
+ gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_OUTPUT; // from the pin/gtb's perspective
+ auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetSourceEntry(gtb.FirstGroupIndex);
+ if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0))
+ {
+ gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName);
+ }
+ }
+ else
+ {
+ RETURN_IF_FAILED(E_INVALIDARG);
+ }
+
+ // name fallback
+ if (gtb.Name.empty())
+ {
+ gtb.Name = masterEndpointDefinition->EndpointName;
+
+ if (gtb.FirstGroupIndex > 0)
+ {
+ gtb.Name += L" " + std::wstring{ gtb.FirstGroupIndex };
+ }
+ }
+
+ // default values as defined in the MIDI 2.0 USB spec
+ gtb.Protocol = 0x01; // midi 1.0
+ gtb.MaxInputBandwidth = 0x0001; // 31.25 kbps
+ gtb.MaxOutputBandwidth = 0x0001; // 31.25 kbps
+
+ groupTerminalBlocks.push_back(gtb);
+ pinMapEntries.push_back(pinMapEntry);
+ }
+
+ // Write Pin Map Property
+ // =====================================================
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Building pin map property", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name")
+ );
+
+ // build the pin map property value
+ KSAGGMIDI_PIN_MAP_PROPERTY_VALUE pinMap{ };
+
+ size_t totalStringSizesIncludingNulls{ 0 };
+ for (auto const& entry : pinMapEntries)
+ {
+ totalStringSizesIncludingNulls += ((entry.FilterId.length() + 1) * sizeof(wchar_t));
+ }
+
+ size_t totalMemoryBytes{
+ SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER +
+ SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING * pinMapEntries.size() +
+ totalStringSizesIncludingNulls };
+
+ pinMapPropertyData.resize(totalMemoryBytes);
+ auto currentPos = pinMapPropertyData.data();
+
+ // header
+ auto pinMapHeader = (PKSAGGMIDI_PIN_MAP_PROPERTY_VALUE)currentPos;
+ pinMapHeader->TotalByteCount = (UINT32)totalMemoryBytes;
+ currentPos += SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER;
+
+ for (auto const& entry : pinMapEntries)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Processing Pin Map entry", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"),
+ TraceLoggingUInt32(entry.PinId, "Pin Id"),
+ TraceLoggingWideString(entry.FilterId.c_str(), "Filter Id")
+ );
+
+ PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY propEntry = (PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY)currentPos;
+
+ propEntry->ByteCount = (UINT)(SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING + ((entry.FilterId.length() + 1) * sizeof(wchar_t)));
+ propEntry->GroupIndex = entry.GroupIndex;
+ propEntry->PinDataFlow = entry.PinDataFlow;
+ propEntry->PinId = entry.PinId;
+
+ if (!entry.FilterId.empty())
+ {
+ wcscpy_s((wchar_t*)propEntry->FilterId, entry.FilterId.length() + 1, entry.FilterId.c_str());
+ }
+
+ currentPos += propEntry->ByteCount;
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"All pin map entries copied to property memory", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name")
+ );
+
+
+ return S_OK;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint(
+ std::shared_ptr endpointDefinition
+)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition);
+
+ std::shared_ptr parentDevice { nullptr };
+ RETURN_IF_FAILED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDevice));
+ RETURN_HR_IF_NULL(E_UNEXPECTED, parentDevice);
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "endpoint name"),
+ TraceLoggingWideString(parentDevice->DeviceName.c_str(), "device name")
+ );
+
+ DEVPROP_BOOLEAN devPropTrue = DEVPROP_TRUE;
+
+ // we require at least one valid pin, and no more than 32 total pins (16 in, 16 out, max)
+ RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() < 1);
+ RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() > 32);
+
+ std::vector interfaceDevProperties;
+
+ MIDIENDPOINTCOMMONPROPERTIES commonProperties{};
+ commonProperties.TransportId = TRANSPORT_LAYER_GUID;
+ commonProperties.EndpointDeviceType = MidiEndpointDeviceType_Normal;
+ commonProperties.FriendlyName = endpointDefinition->EndpointName.c_str();
+ commonProperties.TransportCode = TRANSPORT_CODE;
+ commonProperties.EndpointName = endpointDefinition->EndpointName.c_str();
+ commonProperties.EndpointDescription = nullptr;
+ commonProperties.CustomEndpointName = nullptr;
+ commonProperties.CustomEndpointDescription = nullptr;
+ commonProperties.UniqueIdentifier = parentDevice->SerialNumber.empty() ? nullptr : parentDevice->SerialNumber.c_str();
+ commonProperties.ManufacturerName = parentDevice->ManufacturerName.empty() ? nullptr : parentDevice->ManufacturerName.c_str();
+ commonProperties.SupportedDataFormats = MidiDataFormats::MidiDataFormats_UMP;
+ commonProperties.NativeDataFormat = MidiDataFormats_ByteStream;
+
+ UINT32 capabilities{ 0 };
+ capabilities |= MidiEndpointCapabilities_SupportsMultiClient;
+ capabilities |= MidiEndpointCapabilities_GenerateIncomingTimestamps;
+ capabilities |= MidiEndpointCapabilities_SupportsMidi1Protocol;
+ commonProperties.Capabilities = (MidiEndpointCapabilities)capabilities;
+
+
+ std::vector pinMapPropertyData;
+ std::vector groupTerminalBlocks{ };
+ std::vector nameTablePropertyData;
+
+ RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData(
+ endpointDefinition,
+ pinMapPropertyData,
+ groupTerminalBlocks));
+
+
+ interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() });
+
+
+ // Write Group Terminal Block Property
+ // =====================================================
+
+ std::vector groupTerminalBlockPropertyData{};
+
+ if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockPropertyData))
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_BINARY, static_cast(groupTerminalBlockPropertyData.size()),
+ (PVOID)groupTerminalBlockPropertyData.data() });
+ }
+ else
+ {
+ // write empty data
+ interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_EMPTY, 0, nullptr });
+ }
+
+
+ // Fold in custom properties, including MIDI 1 port names and naming approach
+ // ===============================================================================
+
+ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{};
+ matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId);
+ matchCriteria.UsbVendorId = parentDevice->VID;
+ matchCriteria.UsbProductId = parentDevice->PID;
+ matchCriteria.UsbSerialNumber = parentDevice->SerialNumber;
+ matchCriteria.TransportSuppliedEndpointName = endpointDefinition->EndpointName;
+
+ auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria);
+
+ // rebuild the name table, using the custom properties if present
+ RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties));
+
+ std::wstring customName{ };
+ std::wstring customDescription{ };
+ if (customProperties != nullptr)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD),
+ TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"),
+ TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count")
+ );
+
+ if (!customProperties->Name.empty())
+ {
+ customName = customProperties->Name;
+ commonProperties.CustomEndpointName = customName.c_str();
+ }
+
+ if (!customProperties->Description.empty())
+ {
+ customDescription = customProperties->Description;
+ commonProperties.CustomEndpointDescription = customDescription.c_str();
+ }
+
+ // this includes image, the Midi 1 naming approach, etc.
+ customProperties->WriteNonCommonProperties(interfaceDevProperties);
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD)
+ );
+ }
+
+ // Write Name table property, folding in the custom names we discovered earlier
+ // ===============================================================================================
+ RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties));
+ endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties);
+
+
+ // Write USB VID/PID Data
+ // =====================================================
+
+ if (parentDevice->VID > 0)
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&parentDevice->VID });
+ }
+ else
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_EMPTY, 0, nullptr });
+ }
+
+ if (parentDevice->PID > 0)
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&parentDevice->PID });
+ }
+ else
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_EMPTY, 0, nullptr });
+ }
+
+
+ // Despite being a MIDI 1 device, we present as a UMP endpoint, so we need to set
+ // this property so the service can create the MIDI 1 ports without waiting for
+ // function blocks/discovery to complete or timeout (which it never will)
+ interfaceDevProperties.push_back({ { PKEY_MIDI_EndpointDiscoveryProcessComplete, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_BOOLEAN, (ULONG)sizeof(devPropTrue), (PVOID)&devPropTrue });
+
+ SW_DEVICE_CREATE_INFO createInfo{ };
+
+ createInfo.cbSize = sizeof(createInfo);
+ createInfo.pszInstanceId = endpointDefinition->EndpointDeviceInstanceId.c_str();
+ createInfo.CapabilityFlags = SWDeviceCapabilitiesNone;
+ createInfo.pszDeviceDescription = endpointDefinition->EndpointName.c_str();
+
+ // Call the device manager and finish the creation
+
+ HRESULT swdCreationResult;
+ wil::unique_cotaskmem_string newDeviceInterfaceId;
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Activating endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name")
+ );
+
+ // set to true if we only want to create UMP endpoints
+ bool umpOnly = false;
+
+ LOG_IF_FAILED(
+ swdCreationResult = m_midiDeviceManager->ActivateEndpoint(
+ parentDevice->DeviceInstanceId.c_str(),
+ umpOnly,
+ MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint
+ &commonProperties,
+ (ULONG)interfaceDevProperties.size(),
+ (ULONG)0,
+ interfaceDevProperties.data(),
+ nullptr,
+ &createInfo,
+ &newDeviceInterfaceId)
+ );
+
+ if (SUCCEEDED(swdCreationResult))
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"),
+ TraceLoggingWideString(newDeviceInterfaceId.get(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD)
+ );
+
+ // return new device interface id
+ endpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() });
+
+ auto lock = m_activatedEndpointDefinitionsLock.lock();
+
+ // Add to internal endpoint manager
+ m_activatedEndpointDefinitions.insert_or_assign(
+ internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId),
+ endpointDefinition);
+
+ return swdCreationResult;
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"),
+ TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD)
+ );
+
+ return swdCreationResult;
+ }
+}
+
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(
+ std::shared_ptr endpointDefinition
+)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition);
+
+ std::shared_ptr parentDevice{ nullptr };
+ RETURN_IF_FAILED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDevice));
+ RETURN_HR_IF_NULL(E_UNEXPECTED, parentDevice);
+
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name")
+ );
+
+ // we require at least one valid pin
+ RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() < 1);
+
+ std::vector interfaceDevProperties{ };
+
+ std::vector pinMapPropertyData;
+ std::vector groupTerminalBlocks{ };
+ std::vector nameTablePropertyData;
+
+ // update the pin map to have all the existing pins
+ // plus the new pins. Update Group Terminal Blocks at the same time.
+ RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData(
+ endpointDefinition,
+ pinMapPropertyData,
+ groupTerminalBlocks));
+
+
+ // Write Pin Map Property
+ // =====================================================
+ interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() });
+
+
+ // Write Group Terminal Block Property
+ // =====================================================
+
+ std::vector groupTerminalBlockData;
+
+ if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockData))
+ {
+ interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr },
+ DEVPROP_TYPE_BINARY, (ULONG)groupTerminalBlockData.size(), (PVOID)groupTerminalBlockData.data() });
+ }
+ else
+ {
+ // write empty data
+ }
+
+
+ // Fold in custom properties, including MIDI 1 port names and naming approach
+ // ===============================================================================
+
+ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{};
+ matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId);
+ matchCriteria.UsbVendorId = parentDevice->VID;
+ matchCriteria.UsbProductId = parentDevice->PID;
+ matchCriteria.UsbSerialNumber = parentDevice->SerialNumber;
+ matchCriteria.TransportSuppliedEndpointName = endpointDefinition->EndpointName;
+
+ auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria);
+
+ std::wstring customName{ };
+ std::wstring customDescription{ };
+ if (customProperties != nullptr)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD),
+ TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"),
+ TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count")
+ );
+
+ // this includes image, the Midi 1 naming approach, etc.
+ customProperties->WriteNonCommonProperties(interfaceDevProperties);
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD)
+ );
+ }
+
+ // store the property data for the name table
+ endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties);
+
+
+ // Write Name table property, folding in the custom names we discovered earlier
+ // ===============================================================================================
+ RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties));
+ endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties);
+
+ HRESULT updateResult{};
+
+ LOG_IF_FAILED(updateResult = m_midiDeviceManager->UpdateEndpointProperties(
+ endpointDefinition->EndpointDeviceId.c_str(),
+ static_cast(interfaceDevProperties.size()),
+ interfaceDevProperties.data()
+ ));
+
+
+ if (SUCCEEDED(updateResult))
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Aggregate UMP endpoint updated with new filter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD)
+ );
+
+ auto lock = m_activatedEndpointDefinitionsLock.lock();
+
+ // Add to internal endpoint manager
+ m_activatedEndpointDefinitions.insert_or_assign(
+ internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId),
+ endpointDefinition);
+
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Aggregate UMP endpoint update failed", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"),
+ TraceLoggingHResult(updateResult, MIDI_TRACE_EVENT_HRESULT_FIELD)
+ );
+ }
+
+ return updateResult;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::GetPinName(HANDLE const hFilter, UINT const pinIndex, std::wstring& pinName)
+{
+ std::unique_ptr pinNameData;
+ ULONG pinNameDataSize{ 0 };
+
+ auto pinNameHR = PinPropertyAllocate(
+ hFilter,
+ pinIndex,
+ KSPROPSETID_Pin,
+ KSPROPERTY_PIN_NAME,
+ (PVOID*)&pinNameData,
+ &pinNameDataSize
+ );
+
+ if (SUCCEEDED(pinNameHR) || pinNameHR == HRESULT_FROM_WIN32(ERROR_SET_NOT_FOUND))
+ {
+ // Check to see if the pin has an iJack name
+ if (pinNameDataSize > 0)
+ {
+ pinName = pinNameData.get();
+
+ return S_OK;
+ }
+ }
+
+ return E_FAIL;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow)
+{
+ auto dataFlowHR = PinPropertySimple(
+ hFilter,
+ pinIndex,
+ KSPROPSETID_Pin,
+ KSPROPERTY_PIN_DATAFLOW,
+ &dataFlow,
+ sizeof(KSPIN_DATAFLOW)
+ );
+
+ if (SUCCEEDED(dataFlowHR))
+ {
+ return S_OK;
+ }
+
+ return E_FAIL;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::GetKSDriverSuppliedName(HANDLE hInstantiatedFilter, std::wstring& name)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ // get the name GUID
+
+ KSCOMPONENTID componentId{};
+ KSPROPERTY prop{};
+ ULONG countBytesReturned{};
+
+ prop.Id = KSPROPERTY_GENERAL_COMPONENTID;
+ prop.Set = KSPROPSETID_General;
+ prop.Flags = KSPROPERTY_TYPE_GET;
+
+ auto hrComponent = SyncIoctl(
+ hInstantiatedFilter,
+ IOCTL_KS_PROPERTY,
+ &prop,
+ sizeof(KSPROPERTY),
+ &componentId,
+ sizeof(KSCOMPONENTID),
+ &countBytesReturned
+ );
+
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ // changed to not log the failure here. Failures are expected for many devices, and it's adding noise to error logs
+ if (FAILED(hrComponent))
+ {
+ return hrComponent;
+ }
+ }
+ else
+ {
+ RETURN_IF_FAILED(hrComponent);
+ }
+
+ componentId.Name; // this is the GUID which points to the registry location with the driver-supplied name
+
+ if (componentId.Name != GUID_NULL)
+ {
+ // we have the GUID where this name is stored, so get the driver-supplied name from the registry
+
+ WCHAR nameFromRegistry[MAX_PATH]{ 0 }; // this should only be MAXPNAMELEN, but if someone tampered with it, could be larger, hence MAX_PATH
+
+ std::wstring regKey = L"SYSTEM\\CurrentControlSet\\Control\\MediaCategories\\" + internal::GuidToString(componentId.Name);
+
+ if (SUCCEEDED(wil::reg::get_value_string_nothrow(HKEY_LOCAL_MACHINE, regKey.c_str(), L"Name", nameFromRegistry)))
+ {
+ name = nameFromRegistry;
+ }
+
+ return S_OK;
+ }
+
+ return E_NOTFOUND;
+}
+
+
+
+#define KS_CATEGORY_AUDIO_GUID L"{6994AD04-93EF-11D0-A3CC-00A0C9223196}"
+
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial(
+ std::wstring systemDevicesParentValue,
+ std::shared_ptr parentDevice)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice);
+
+ if (systemDevicesParentValue.empty())
+ {
+ RETURN_IF_FAILED(E_INVALIDARG);
+ }
+
+
+ // Examples
+ // ---------------------------------------------------------------------------
+ // Parent values with serial: USB\VID_16C0&PID_05E4\ContinuuMini_SN024066
+ // USB\VID_2573&PID_008A\no_serial_number (yes, this is the iSerialNumber in USB :/
+ // 0x03 iSerialNumber "no serial number"
+ // USB\VID_12E6&PID_002C\251959d4f21fc283
+ //
+ // Parent values without serial: USB\VID_F055&PID_0069\8&2858bbac&0&4
+ // USB\VID_2662&PID_000D\8&24eb0394&0&4
+ // USB\VID_05E3&PID_0610\7&2f028fc9&0&4
+ // ROOT\MOTUBUS\0000
+
+ std::wstring parentVal = systemDevicesParentValue.c_str();
+
+ std::wstringstream ss(parentVal);
+ std::wstring usbSection{};
+
+ std::getline(ss, usbSection, static_cast('\\'));
+
+ if (usbSection == L"USB")
+ {
+ // get the VID/PID section
+
+ std::wstring vidPidSection{};
+
+ std::getline(ss, vidPidSection, static_cast('\\'));
+
+ if (!vidPidSection.empty())
+ {
+ std::wstring serialSection{};
+ std::getline(ss, serialSection, static_cast('\\'));
+
+ std::wstring vidPidString1{};
+ std::wstring vidPidString2{};
+
+ std::wstringstream ssVidPid(vidPidSection);
+ std::getline(ssVidPid, vidPidString1, static_cast('&'));
+ std::getline(ssVidPid, vidPidString2, static_cast('&'));
+
+ wchar_t* end{ nullptr };
+
+ // find the VID
+ if (vidPidString1.starts_with(L"VID_"))
+ {
+ parentDevice->VID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16));
+ }
+ else if (vidPidString2.starts_with(L"VID_"))
+ {
+ parentDevice->VID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16));
+ }
+
+ // find the PID
+ if (vidPidString1.starts_with(L"PID_"))
+ {
+ parentDevice->PID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16));
+ }
+ else if (vidPidString2.starts_with(L"PID_"))
+ {
+ parentDevice->PID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16));
+ }
+
+ // serial numbers with a & in them, are generated by our system
+ // it's possible a vendor may have a serial number with this in it,
+ // but in that case, we just ditch it.
+ if (serialSection.find_first_of('&') == serialSection.npos)
+ {
+ // Windows replaces spaces in the serial number with the underscore.
+ // yes, this will end up catching the few (if any) serials that do
+ // actually include an underscore. However, there are a bunch with spaces.
+ std::replace(serialSection.begin(), serialSection.end(), '_', ' ');
+ parentDevice->SerialNumber = serialSection;
+ }
+ }
+ }
+ else
+ {
+ // not a USB device, or otherwise uses a custom driver. We can't count
+ // on being able to parse the parent id. Example: MOTU has
+ // ROOT\MOTUBUS\0000 as the parent
+ }
+
+ return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindPendingEndpointDefinitionForParentDevice(
+ std::wstring parentDeviceInstanceId,
+ std::shared_ptr& endpointDefinition)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id")
+ );
+
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId);
+
+ for (auto const& endpoint : m_pendingEndpointDefinitions)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint->ParentDeviceInstanceId) == cleanParentDeviceInstanceId)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Match found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id")
+ );
+
+ endpointDefinition = endpoint;
+ return S_OK;
+ }
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id")
+ );
+
+ endpointDefinition = nullptr;
+
+ return E_NOTFOUND;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterDevice(
+ std::wstring filterDeviceId,
+ std::shared_ptr& endpointDefinition
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDeviceId.c_str(), "filter device id")
+ );
+
+ auto cleanFilterDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(filterDeviceId);
+
+ for (auto const& endpoint : m_activatedEndpointDefinitions)
+ {
+ for (auto const& pin: endpoint.second->MidiPins)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Matching Endpoint found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id")
+ );
+
+ if (internal::NormalizeEndpointInterfaceIdWStringCopy(pin->FilterDeviceId) == cleanFilterDeviceId)
+ {
+ endpointDefinition = endpoint.second;
+ return S_OK;
+ }
+ }
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id")
+ );
+
+ endpointDefinition = nullptr;
+
+ return E_NOTFOUND;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindAllActivatedEndpointDefinitionsForParentDevice(
+ std::wstring parentDeviceInstanceId,
+ std::vector>& endpointDefinitions
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId);
+
+ bool found { false };
+
+ for (auto const& endpoint : m_activatedEndpointDefinitions)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint.second->ParentDeviceInstanceId) == cleanParentDeviceInstanceId)
+ {
+ endpointDefinitions.push_back(endpoint.second);
+ found = true;
+ }
+ }
+
+
+ if (found)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"One or more matching endpoints found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ return S_OK;
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No matches found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ return E_NOTFOUND;
+ }
+
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindAllPendingEndpointDefinitionsForParentDevice(
+ std::wstring parentDeviceInstanceId,
+ std::vector>& endpointDefinitions
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId);
+
+ bool found{ false };
+
+ for (auto const& endpoint : m_pendingEndpointDefinitions)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint->ParentDeviceInstanceId) == cleanParentDeviceInstanceId)
+ {
+ endpointDefinitions.push_back(endpoint);
+ found = true;
+ }
+ }
+
+
+ if (found)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"One or more matching endpoints found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ return S_OK;
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No matches found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ return E_NOTFOUND;
+ }
+
+}
+
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindExistingParentDeviceDefinitionForEndpoint(
+ std::shared_ptr endpointDefinition,
+ std::shared_ptr& parentDeviceDefinition
+)
+{
+ RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition);
+
+ auto cleanDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId);
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Looking for matching parent", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ if (auto parent = m_allParentDeviceDefinitions.find(cleanDeviceInstanceId); parent != m_allParentDeviceDefinitions.end())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Match found.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ parentDeviceDefinition = parent->second;
+
+ return S_OK;
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id")
+ );
+
+ return E_NOTFOUND;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilterDevice(
+ DeviceInformation filterDevice,
+ std::shared_ptr& parentDeviceDefinition
+)
+{
+ // we require that the System.Devices.DeviceInstanceId property was requested for the passed-in filter device
+ auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L"");
+ RETURN_HR_IF(E_FAIL, deviceInstanceId.empty());
+
+ auto additionalProperties = winrt::single_threaded_vector();
+ additionalProperties.Append(L"System.Devices.DeviceManufacturer");
+ additionalProperties.Append(L"System.Devices.Manufacturer");
+ additionalProperties.Append(L"System.Devices.Parent");
+
+ auto parentDevice = DeviceInformation::CreateFromIdAsync(
+ deviceInstanceId,
+ additionalProperties,
+ winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get();
+
+
+ auto lock = m_allParentDeviceDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing
+
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str());
+
+ if (auto it = m_allParentDeviceDefinitions.find(cleanParentDeviceInstanceId); it != m_allParentDeviceDefinitions.end())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Found existing parent device.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent")
+ );
+
+ // we found a matching parent device. Return it.
+ parentDeviceDefinition = it->second;
+
+ return S_OK;
+ }
+
+ // We don't have one, create one and add, and get all the parent device information for it
+ // we still have the map locked, so keep this code fast
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Parent device definition not already created. Creating now.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent")
+ );
+
+ auto newParentDeviceDefinition = std::make_shared();
+ RETURN_HR_IF_NULL(E_OUTOFMEMORY, newParentDeviceDefinition);
+
+ newParentDeviceDefinition->DeviceName = parentDevice.Name();
+ newParentDeviceDefinition->DeviceInstanceId = cleanParentDeviceInstanceId;
+
+ LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newParentDeviceDefinition->DeviceInstanceId, newParentDeviceDefinition));
+
+ // only some vendor drivers provide an actual manufacturer
+ // and all the in-box drivers just provide the Generic USB Audio string
+ // TODO: Is "Generic USB Audio" a string that is localized? If so, this code will not have the intended effect outside of en-US
+ auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L"");
+ auto manufacturer2 = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Manufacturer", parentDevice, L"");
+ if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft")
+ {
+ newParentDeviceDefinition->ManufacturerName = manufacturer;
+ }
+ else if (!manufacturer2.empty() && manufacturer2 != L"(Generic USB Audio)" && manufacturer2 != L"Microsoft")
+ {
+ newParentDeviceDefinition->ManufacturerName = manufacturer2;
+ }
+
+ // Do we need to disambiguate this parent because another of the same device already exists?
+
+ uint32_t currentMaxIndex { 0 };
+ bool otherParentsWithSameNameExist{ false };
+
+ for (auto const& existingParent : m_allParentDeviceDefinitions)
+ {
+ if (existingParent.second->DeviceName == newParentDeviceDefinition->DeviceName)
+ {
+ currentMaxIndex = max(currentMaxIndex, existingParent.second->IndexOfDevicesWithThisSameName);
+ otherParentsWithSameNameExist = true;
+ }
+ }
+
+ if (otherParentsWithSameNameExist)
+ {
+ parentDeviceDefinition->IndexOfDevicesWithThisSameName = currentMaxIndex + 1;
+ }
+
+
+ m_allParentDeviceDefinitions[newParentDeviceDefinition->DeviceInstanceId] = newParentDeviceDefinition;
+ parentDeviceDefinition = newParentDeviceDefinition;
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Parent device definition added.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(newParentDeviceDefinition->DeviceInstanceId.c_str(), "key (device instance id)")
+ );
+
+ return S_OK;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindCurrentMaxEndpointIndexForParentDevice(
+ std::shared_ptr parentDeviceDefinition,
+ uint32_t& currentMaxIndex)
+{
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceDefinition->DeviceInstanceId);
+
+ int32_t maxIndex{ -1 };
+ bool found{ false };
+
+ auto activatedLock = m_activatedEndpointDefinitionsLock.lock();
+ auto pendingLock = m_pendingEndpointDefinitionsLock.lock();
+
+ // look through all pending and activated endpoints and find the max.
+ // If the max is 0 or greater, return it and set S_OK.
+ // if no endpoints found, return E_NOTFOUND so the calling code
+ // knows that the 0 is not in use
+
+ for (auto const& ep : m_activatedEndpointDefinitions)
+ {
+ if (ep.second->ParentDeviceInstanceId == cleanParentDeviceInstanceId)
+ {
+ maxIndex++;
+ found = true;
+ }
+ }
+
+ for (auto const& ep : m_pendingEndpointDefinitions)
+ {
+ if (ep->ParentDeviceInstanceId == cleanParentDeviceInstanceId)
+ {
+ maxIndex++;
+ found = true;
+ }
+ }
+
+
+ if (found)
+ {
+ currentMaxIndex = static_cast(maxIndex);
+ return S_OK;
+ }
+ else
+ {
+ return E_NOTFOUND;
+ }
+
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice(
+ DeviceInformation filterDevice,
+ std::shared_ptr& endpointDefinition
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id")
+ );
+
+ std::shared_ptr parentDeviceDefinition{ nullptr };
+
+ // this function locks the parent device list for the duration of the call
+ RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice(
+ filterDevice,
+ parentDeviceDefinition
+ ));
+
+ RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition);
+
+
+ // See if we already have an endpoint with space for the number of groups we're going to add
+
+ std::shared_ptr existingEndpointDefinition { nullptr };
+ if (SUCCEEDED(FindPendingEndpointDefinitionForParentDevice(parentDeviceDefinition->DeviceInstanceId, existingEndpointDefinition)))
+ {
+ RETURN_HR_IF_NULL(E_UNEXPECTED, existingEndpointDefinition);
+
+ // TODO: Need to check to see if this endpoint has enough space for the new pins
+
+ endpointDefinition = existingEndpointDefinition;
+
+ return S_OK;
+ }
+ else
+ {
+
+ // create a new endpoint
+ auto newEndpointDefinition = std::make_shared();
+ RETURN_HR_IF_NULL(E_POINTER, newEndpointDefinition);
+
+ // We need to ensure each endpoint has a unique id. They can't all use the ParentDeviceInstanceId as the
+ // instance id because now some devices will have multiple endpoints. Instead, we need to add a suffix to
+ // that. We need this to be deterministic and not just a random GUID/number, so that device ids have a
+ // chance to match up when next enumerated after a restart or connect/disconnect.
+
+ auto parentLock = m_allParentDeviceDefinitionsLock.lock();
+
+ uint32_t endpointIndexForThisParent{ 0 };
+ if (SUCCEEDED(FindCurrentMaxEndpointIndexForParentDevice(parentDeviceDefinition, endpointIndexForThisParent)))
+ {
+ // increment the number here
+ endpointIndexForThisParent++;
+ }
+
+ newEndpointDefinition->ParentDeviceInstanceId = parentDeviceDefinition->DeviceInstanceId;
+ newEndpointDefinition->EndpointIndexForThisParentDevice = endpointIndexForThisParent;
+
+ // default hash is the device id.
+ std::hash hasher;
+ std::wstring hash;
+ hash = std::to_wstring(hasher(parentDeviceDefinition->DeviceInstanceId));
+
+ if (endpointIndexForThisParent == 0)
+ {
+ newEndpointDefinition->EndpointName = parentDeviceDefinition->DeviceName;
+ newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash;
+ }
+ else
+ {
+ // pad the string with "0" characters to the left of the number, up to 3 places total.
+ // we +1 so the second one is _002 and not _001
+ newEndpointDefinition->EndpointDeviceInstanceId = std::format(L"{0}{1}_{2:0>3}", TRANSPORT_INSTANCE_ID_PREFIX, hash, endpointIndexForThisParent + 1);
+
+ // add the name disambiguator to the endpoint. We +1 to the index for the same reasons as above.
+ newEndpointDefinition->EndpointName = std::format(L"{0} ({1})", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1);
+ //newEndpointDefinition->EndpointName = std::format(L"{1} - {0}", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1);
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ m_pendingEndpointDefinitions.push_back(newEndpointDefinition);
+ endpointDefinition = newEndpointDefinition;
+
+ return S_OK;
+ }
+
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::IncrementAndGetNextGroupIndex(
+ std::shared_ptr definition,
+ MidiFlow dataFlowFromUserPerspective,
+ uint8_t& groupIndex)
+{
+ // the structure is initialized with -1 for the current group index in each direction
+
+ if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn)
+ {
+ definition->CurrentHighestMidiSourceGroupIndex++;
+ groupIndex = definition->CurrentHighestMidiSourceGroupIndex;
+ }
+ else
+ {
+ definition->CurrentHighestMidiDestinationGroupIndex++;
+ groupIndex = definition->CurrentHighestMidiDestinationGroupIndex;
+ }
+
+ return S_OK;
+}
+
+#define MAX_THREAD_WORKER_WAIT_TIME_MS 20000
+
+_Use_decl_annotations_
+void CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker(
+ std::stop_token token)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ while (!token.stop_requested())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Waiting to be woken up.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ // wait to be woken up
+ m_endpointCreationThreadWakeup.wait(MAX_THREAD_WORKER_WAIT_TIME_MS);
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: I'm awake now.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ if (!token.stop_requested())
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Short nap before checking.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingUInt32(m_individualInterfaceEnumTimeoutMS, "nap period (ms)")
+ );
+
+ // we sleep for this timeout before we check to see if the thread is signaled.
+ // this gives time for an additional interface notification to cause the event
+ // to be reset
+ Sleep(m_individualInterfaceEnumTimeoutMS);
+
+ // if we're still signaled, that means no other pnp notifications came in during the nap, or if they
+ // did, they completed within that nap period
+ if (m_endpointCreationThreadWakeup.is_signaled() && m_pendingEndpointDefinitions.size() > 0)
+ {
+ m_endpointCreationThreadWakeup.ResetEvent(); // it's a manual reset event
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Thread was signaled and pending definition count > 0. Proceed to processing queue.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+
+ // lock the definitions so we can process them
+ auto lock = m_pendingEndpointDefinitionsLock.lock();
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Processing pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+ while (m_pendingEndpointDefinitions.size() > 0)
+ {
+ // effectively a queue, but because we have to iterate and search
+ // this in other functions, a vector is more appropriate
+ auto ep = m_pendingEndpointDefinitions[0];
+ m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin());
+
+ // create the endpoint
+ LOG_IF_FAILED(DeviceCreateMidiUmpEndpoint(ep));
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+ }
+ }
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Exit", MIDI_TRACE_EVENT_MESSAGE_FIELD)
+ );
+
+
+}
+
+_Use_decl_annotations_
+bool CMidi2KSAggregateMidiEndpointManager2::ActiveKSAEndpointForDeviceExists(
+ _In_ std::wstring parentDeviceInstanceId)
+{
+ auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str());
+
+ for (auto const& entry : m_activatedEndpointDefinitions)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == cleanParentDeviceInstanceId)
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+
+bool ShouldSkipOpeningKsPin(_In_ KsHandleWrapper& deviceHandleWrapper, _In_ UINT pinIndex)
+{
+ if (Feature_Servicing_MIDI2FilterCreations::IsEnabled())
+ {
+ std::unique_ptr dataRanges;
+ ULONG dataRangesSize{ 0 };
+
+ // skip this pin if for some reason data ranges aren't valid
+ if (FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return RetrieveDataRanges(h, pinIndex, (PKSMULTIPLE_ITEM*)&dataRanges, &dataRangesSize);
+ })))
+ {
+ return true;
+ }
+
+ // Skip this pin if it supports cyclic ump, or if it doesn't support bytestream
+ if (SUCCEEDED(DataRangeSupportsTransport(dataRanges.get(), MidiTransport_CyclicUMP)) ||
+ FAILED(DataRangeSupportsTransport(dataRanges.get(), MidiTransport_StandardByteStream)))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins(
+ DeviceInformation filterDevice,
+ std::vector>& pinListToAddTo,
+ uint8_t& countMidiSourcePinsAdded,
+ uint8_t& countMidiDestinationPinsAdded
+ )
+{
+ // Wrapper opens the handle internally.
+ KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str());
+ RETURN_IF_FAILED(deviceHandleWrapper.Open());
+
+ countMidiSourcePinsAdded = 0;
+ countMidiDestinationPinsAdded = 0;
+
+ // Driver-supplied name. This is needed for WinMM backwards compatibility
+ std::wstring driverSuppliedName{};
+
+ // Using lamba function to prevent handle from dissapearing when being used.
+ // we don't log the HRESULT because this is not critical and will often fail,
+ // adding unnecessary noise to the error logs
+ deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return GetKSDriverSuppliedName(h, driverSuppliedName);
+ });
+
+ // =============================================================================================
+ // Go through all the enumerated pins, looking for a MIDI 1.0 pin
+
+ // enumerate all the pins for this filter
+ ULONG cPins{ 0 };
+
+ RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins));
+ }));
+
+ ULONG midiInputPinIndexForThisFilter{ 0 };
+ ULONG midiOutputPinIndexForThisFilter{ 0 };
+
+ // process the pins for this filter. Not all will be MIDI pins
+ for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++)
+ {
+ // Check the communication capabilities of the pin so we can fail fast
+ KSPIN_COMMUNICATION communication = (KSPIN_COMMUNICATION)0;
+
+ RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return PinPropertySimple(h, pinIndex, KSPROPSETID_Pin, KSPROPERTY_PIN_COMMUNICATION, &communication, sizeof(KSPIN_COMMUNICATION));
+ }));
+
+ // The external connector pin representing the physical connection
+ // has KSPIN_COMMUNICATION_NONE. We can only create on software IO pins which
+ // have a valid communication. Skip connector pins.
+ if (communication == KSPIN_COMMUNICATION_NONE)
+ {
+ continue;
+ }
+
+ if (ShouldSkipOpeningKsPin(deviceHandleWrapper, pinIndex))
+ {
+ continue;
+ }
+
+ // Duplicate the handle to safely pass it to another component or store it.
+ wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle());
+ RETURN_IF_NULL_ALLOC(handleDupe);
+
+ KsHandleWrapper pinHandleWrapper(
+ filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get());
+
+ if (SUCCEEDED(pinHandleWrapper.Open()))
+ {
+ auto pinDefinition = std::make_shared();
+ RETURN_HR_IF_NULL(E_POINTER, pinDefinition);
+
+ //pinDefinition.KSDriverSuppliedName = driverSuppliedName;
+ pinDefinition->PinNumber = pinIndex;
+ pinDefinition->FilterDeviceId = std::wstring{ filterDevice.Id() };
+ pinDefinition->FilterName = std::wstring{ filterDevice.Name() };
+
+ // find the name of the pin (supplied by iJack, and if not available, generated by the stack)
+ std::wstring pinName{ };
+
+ HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return GetPinName(h, pinIndex, pinName);
+ });
+
+ if (SUCCEEDED(pinNameHr))
+ {
+ pinDefinition->PinName = pinName;
+ }
+
+ // get the data flow so we know if this is a MIDI Input or a MIDI Output
+ KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0;
+
+ HRESULT dataFlowHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT {
+ return GetPinDataFlow(h, pinIndex, dataFlow);
+ });
+
+
+
+ pinDefinition->DriverSuppliedName = driverSuppliedName;
+
+
+ if (SUCCEEDED(dataFlowHr))
+ {
+ if (dataFlow == KSPIN_DATAFLOW_IN)
+ {
+ // MIDI Out (input to device)
+ pinDefinition->PinDataFlow = MidiFlow::MidiFlowIn;
+ pinDefinition->DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite
+ pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter);
+
+ midiOutputPinIndexForThisFilter++;
+
+ countMidiDestinationPinsAdded++;
+ }
+ else if (dataFlow == KSPIN_DATAFLOW_OUT)
+ {
+ // MIDI In (output from device)
+ pinDefinition->PinDataFlow = MidiFlow::MidiFlowOut;
+ pinDefinition->DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite
+ pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter);
+
+ midiInputPinIndexForThisFilter++;
+
+ countMidiSourcePinsAdded++;
+ }
+
+ pinListToAddTo.push_back(pinDefinition);
+ }
+ else
+ {
+ // this is a failure condition. Move on to next pin
+ LOG_IF_FAILED(dataFlowHr);
+ continue;
+ }
+ }
+ }
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"MIDI 1.0 pins enumerated", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"),
+ TraceLoggingUInt32(static_cast(pinListToAddTo.size()), "Total size of pin list including new pins.")
+ );
+
+
+ return S_OK;
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions(
+ std::wstring filterDeviceid,
+ std::shared_ptr endpointDefinition)
+{
+ // At this point, we need to have *all* the pins for the endpoint, not just this filter
+ for (auto& pinDefinition : endpointDefinition->MidiPins)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition->FilterDeviceId) !=
+ internal::NormalizeDeviceInstanceIdWStringCopy(filterDeviceid))
+ {
+ // only process the pins for this filter interface. We don't want to
+ // change anything that has already been built. But we do need the
+ // context of all pins when getting the group index.
+ continue;
+ }
+
+ // Figure out the group index for the pin. This needs the context of the
+ // entire device. Failure to get the group index is fatal
+ RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(
+ endpointDefinition,
+ pinDefinition->DataFlowFromUserPerspective,
+ pinDefinition->GroupIndex
+ ));
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDeviceid.c_str(), "filter device id"),
+ TraceLoggingUInt8(pinDefinition->GroupIndex, "group index"),
+ TraceLoggingWideString(pinDefinition->DataFlowFromUserPerspective ==
+ MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective")
+ );
+
+ // guard against errors resulting in invalid group indexes
+ RETURN_HR_IF(E_UNEXPECTED, !IS_VALID_GROUP_INDEX(pinDefinition->GroupIndex));
+
+
+ std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update
+
+ // Build the name table entry for this individual pin
+ endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver(
+ pinDefinition->GroupIndex,
+ pinDefinition->DataFlowFromUserPerspective,
+ customName,
+ pinDefinition->DriverSuppliedName,
+ pinDefinition->FilterName,
+ pinDefinition->PinName,
+ pinDefinition->PortIndexWithinThisFilterAndDirection
+ );
+ }
+
+ return S_OK;
+}
+
+bool EndpointHasRoomForMoreNewPins(
+ _In_ std::shared_ptr endpoint,
+ _In_ uint8_t countNewSourcePins,
+ _In_ uint8_t countNewDestinationPins)
+{
+
+ uint8_t countFoundSourcePins{ 0 };
+ uint8_t countFoundDestinationPins{ 0 };
+
+ // count the source and destination pins
+
+ for (auto const& pin : endpoint->MidiPins)
+ {
+ if (pin->DataFlowFromUserPerspective == MidiFlow::MidiFlowIn)
+ {
+ countFoundSourcePins++;
+ }
+ else
+ {
+ countFoundDestinationPins++;
+ }
+ }
+
+ if ((countFoundSourcePins + countNewSourcePins <= 16) &&
+ (countFoundDestinationPins + countNewDestinationPins <= 16))
+ {
+ return true;
+ }
+
+ return false;
+
+}
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded(
+ DeviceWatcher /* watcher */,
+ DeviceInformation filterDevice
+)
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "added interface")
+ );
+
+
+ // 1. Get the list of pins that are the right category for us to try to activate
+ // 2. Loop through and build final list of all MIDI 1.0 source and destination pins
+ // 3. Do we already have an activated endpoint for this device?
+ // 3.1 If we do, then see if it has room for these pins.
+ // 3.1.1 If it has room, then add these pins to the endpoint
+ // 3.1.2 If it doesn't have room, then build a new endpoint for this device
+ // and add that endpoint to the pending endpoints list
+ // 3.2 If we do not have an activated endpoint, see if we have a pending endpoint
+ //
+
+
+ std::wstring transportCode(TRANSPORT_CODE);
+
+ // Wrapper opens the handle internally.
+ KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str());
+ RETURN_IF_FAILED(deviceHandleWrapper.Open());
+
+ // get all the MIDI 1 pins. These are only partially processed because some things
+ // like group index and naming require the full context of all filters/pins on the
+ // parent device. We want to get these before we try creating parents or endpoints
+ std::vector> pinList{ };
+ uint8_t countEnumeratedMidiSourcePins{ 0 };
+ uint8_t countEnumeratedMidiDestinationPins{ 0 };
+ RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins));
+
+ if (pinList.size() == 0)
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"No MIDI 1.0 pins for this filter device", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id")
+ );
+
+ return S_OK;
+ }
+ else if (pinList.size() > 32)
+ {
+ // we don't support more than 32 pins, 16 in, 16 out, on a single endpoint.
+ // We also don't support splitting a filter across more than one endpoint
+ // so we can only enumerate the first 16 in each direction
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_ERROR,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Too many MIDI pins for this filter. Maximum of 16 in and 16 out allowed per KS filter.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id")
+ );
+
+ RETURN_IF_FAILED(E_FAIL);
+ }
+
+
+ // We have MIDI 1.0 pins to process, so we'll need to find or create a parent device
+ // and also find or create an endpoint under that parent, which has sufficient room
+ // for these pins.
+
+
+ // ===================================================================
+ // Find or create a parent device definition
+
+ auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L"");
+ RETURN_HR_IF(E_FAIL, parentInstanceId.empty());
+
+ std::shared_ptr parentDeviceDefinition{ nullptr };
+
+ RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice(
+ filterDevice,
+ parentDeviceDefinition
+ ));
+
+ std::vector> foundEndpoints{};
+
+ // do we already have one or more pending endpoints for this?
+ if (SUCCEEDED(FindAllPendingEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints)))
+ {
+ std::shared_ptr existingPendingEndpointDefinition{ nullptr };
+
+ // find an endpoint with room for another interface with pins.
+ // We're going by the pin counts returned when we enumerated
+ // all pins for this interface
+
+ // check the latest endpoint first
+
+ for (size_t i = foundEndpoints.size() - 1; i >= 0; i--)
+ {
+ auto ep = foundEndpoints[i];
+
+ if (EndpointHasRoomForMoreNewPins(ep, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins))
+ {
+ existingPendingEndpointDefinition = ep;
+
+ break;
+ }
+ }
+
+ // if this check fails, we fall through to creating a new endpoint definition
+ if (existingPendingEndpointDefinition != nullptr)
+ {
+ //// first MIDI 1 pin we're processing for this interface
+ //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition));
+ //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition);
+
+ // add our new pins into the existing endpoint definition
+ existingPendingEndpointDefinition->MidiPins.insert(existingPendingEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end());
+ RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingPendingEndpointDefinition));
+
+ return S_OK;
+ }
+
+ }
+ else if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str()))
+ {
+ std::shared_ptr existingActivatedEndpointDefinition{ nullptr };
+
+ // check to see if we already have any activated endpoints for this device
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"KSA endpoint for this filter already activated. Updating it.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"),
+ TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id")
+ );
+
+
+ // TODO: Potential problem with this code:
+ // Because it's just using pin counts, after some add/removes
+ // the group indexes could change. Is that ok? Seems not.
+ // need to see what impact that will have on enumerated MIDI ports
+
+ if (SUCCEEDED(FindAllActivatedEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints)))
+ {
+ // find an endpoint with room for another interface with pins.
+ // We're going by the pin counts returned when we enumerated
+ // all pins for this interface
+
+ // check the latest endpoint first
+
+ for (size_t i = foundEndpoints.size()-1; i >= 0; i--)
+ {
+ auto ep = foundEndpoints[i];
+
+ if (EndpointHasRoomForMoreNewPins(ep, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins))
+ {
+ existingActivatedEndpointDefinition = ep;
+
+ break;
+ }
+
+ }
+ }
+
+ // if this check fails, we fall through to creating a new endpoint definition
+ if (existingActivatedEndpointDefinition != nullptr)
+ {
+ //// first MIDI 1 pin we're processing for this interface
+ //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition));
+ //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition);
+
+ // add our new pins into the existing endpoint definition
+ existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end());
+ RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingActivatedEndpointDefinition));
+
+ RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition));
+
+ return S_OK;
+ }
+ }
+ else
+ {
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_VERBOSE,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Endpoint for this device does not already exist. Proceed to creating new one.", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"),
+ TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id")
+ );
+ }
+
+
+ // ===================================================================
+ // Create the endpoint
+
+ std::shared_ptr endpointDefinition{ nullptr };
+
+ RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition));
+ RETURN_HR_IF_NULL(E_POINTER, endpointDefinition);
+
+ // add our new pins
+ endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end());
+ pinList.clear(); // just make sure we don't use this one, accidentally
+
+ RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), endpointDefinition));
+
+ // we have an endpoint definition
+ m_endpointCreationThreadWakeup.SetEvent();
+
+ return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved(
+ DeviceWatcher watcher,
+ DeviceInformationUpdate deviceInterfaceUpdate
+)
+{
+ UNREFERENCED_PARAMETER(watcher);
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
+ TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "removed interface")
+ );
+
+ std::wstring removedFilterDeviceId{ internal::NormalizeDeviceInstanceIdWStringCopy(deviceInterfaceUpdate.Id().c_str()) };
+
+ // find an active device with this filter
+
+ std::shared_ptr endpointDefinition{ nullptr };
+
+ for (auto& endpointListIterator : m_activatedEndpointDefinitions)
+ {
+ // check pins for this filter
+ for (auto& pin: endpointListIterator.second->MidiPins)
+ {
+ if (internal::NormalizeDeviceInstanceIdWStringCopy(pin->FilterDeviceId) == removedFilterDeviceId)
+ {
+ endpointDefinition = endpointListIterator.second;
+ break;
+ }
+ }
+
+ }
+
+ if (endpointDefinition != nullptr)
+ {
+ bool done { false };
+
+ while (!done)
+ {
+ auto foundIt = std::find_if(
+ endpointDefinition->MidiPins.begin(),
+ endpointDefinition->MidiPins.end(),
+ [&removedFilterDeviceId](std::shared_ptr pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin->FilterDeviceId) == removedFilterDeviceId; }
+ );
+
+ if (foundIt != endpointDefinition->MidiPins.end())
+ {
+ // erase the pin definition with this
+ endpointDefinition->MidiPins.erase(foundIt);
+ }
+ else
+ {
+ // we've removed all the pins for this interface
+ done = true;
+ }
+ }
+
+ if (endpointDefinition->MidiPins.size() > 0)
+ {
+ // we've removed all the pins for this interface, but there are still
+ // pins left, so now it's time to update the endpoint
+
+
+ // TODO: Need to cache the name from the driver/registry so we don't have to do a lookup here.
+
+ std::shared_ptr parentDeviceDefinition{ nullptr };
+
+ if (SUCCEEDED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDeviceDefinition)))
+ {
+ RETURN_HR_IF_NULL(E_UNEXPECTED, parentDeviceDefinition);
+
+ // update remaining pins in existing endpoint definition
+ RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, endpointDefinition));
+ RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition));
+ }
+ else
+ {
+ RETURN_IF_FAILED(E_NOTFOUND);
+ }
+ }
+ else
+ {
+ auto lock = m_activatedEndpointDefinitionsLock.lock();
+
+ // notify the device manager using the InstanceId for this midi device
+ RETURN_IF_FAILED(m_midiDeviceManager->RemoveEndpoint(
+ internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId).c_str()));
+
+ // remove the endpoint from the list
+
+ m_activatedEndpointDefinitions.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId));
+ }
+ }
+
+ return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceUpdated(
+ DeviceWatcher watcher,
+ DeviceInformationUpdate deviceInterfaceUpdate
+)
+{
+ UNREFERENCED_PARAMETER(watcher);
+
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "updated interface")
+ );
+
+ // Flow for interface UPDATED
+ // - Check for any name changes
+ // - If any relevant changes recalculate GTBs and Name table as above and update properties
+ //
+
+
+
+ return S_OK;
+}
+
+
+
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::OnDeviceWatcherStopped(DeviceWatcher watcher, winrt::Windows::Foundation::IInspectable)
+{
+ UNREFERENCED_PARAMETER(watcher);
+
+ m_EnumerationCompleted.SetEvent();
+ return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::OnEnumerationCompleted(DeviceWatcher watcher, winrt::Windows::Foundation::IInspectable)
+{
+ UNREFERENCED_PARAMETER(watcher);
+
+ m_EnumerationCompleted.SetEvent();
+ return S_OK;
+}
+
+
+_Use_decl_annotations_
+winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEndpoint(
+ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria)
+{
+ criteria.Normalize();
+
+ for (auto const& def : m_activatedEndpointDefinitions)
+ {
+ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{};
+
+ available.DeviceInstanceId = def.second->EndpointDeviceInstanceId;
+ available.EndpointDeviceId = def.second->EndpointDeviceId;
+ available.TransportSuppliedEndpointName = def.second->EndpointName;
+
+ std::shared_ptr parent { nullptr };
+
+ if (SUCCEEDED(FindExistingParentDeviceDefinitionForEndpoint(def.second, parent)))
+ {
+ available.UsbVendorId = parent->VID;
+ available.UsbProductId = parent->PID;
+ available.UsbSerialNumber = parent->SerialNumber;
+ available.DeviceManufacturerName = parent->ManufacturerName;
+ }
+
+
+ if (available.Matches(criteria))
+ {
+ return available.EndpointDeviceId;
+ }
+ }
+
+ return L"";
+}
+
+HRESULT
+CMidi2KSAggregateMidiEndpointManager2::Shutdown()
+{
+ TraceLoggingWrite(
+ MidiKSAggregateTransportTelemetryProvider::Provider(),
+ MIDI_TRACE_EVENT_INFO,
+ TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingPointer(this, "this")
+ );
+
+ m_endpointCreationThread.request_stop();
+ m_endpointCreationThreadWakeup.SetEvent();
+
+ m_DeviceAdded.revoke();
+ m_DeviceRemoved.revoke();
+ m_DeviceUpdated.revoke();
+ m_DeviceStopped.revoke();
+
+ m_DeviceEnumerationCompleted.revoke();
+
+ m_watcher.Stop();
+
+ uint8_t tries{ 0 };
+ while (m_watcher.Status() != DeviceWatcherStatus::Stopped && tries < 50)
+ {
+ Sleep(100);
+ tries++;
+ }
+
+ TransportState::Current().Shutdown();
+
+ m_midiDeviceManager.reset();
+ m_midiProtocolManager.reset();
+
+ return S_OK;
+}
\ No newline at end of file
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h
new file mode 100644
index 00000000..06ca403a
--- /dev/null
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h
@@ -0,0 +1,216 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License
+// ============================================================================
+// This is part of the Windows MIDI Services App API and should be used
+// in your Windows application via an official binary distribution.
+// Further information: https://aka.ms/midi
+// ============================================================================
+
+#pragma once
+
+#include "MidiEndpointCustomProperties.h"
+#include "MidiEndpointMatchCriteria.h"
+#include "MidiEndpointNameTable.h"
+
+using namespace winrt::Windows::Devices::Enumeration;
+
+#define DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS 250
+#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE 50
+#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE 2500
+#define KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE L"KsaInterfaceEnumTimeoutMS"
+
+struct KsAggregateEndpointMidiPinDefinition2
+{
+ std::wstring FilterDeviceId; // this is also the value needed by WinMM for DRV_QUERYDEVICEINTERFACE
+ std::wstring FilterName;
+
+ std::wstring DriverSuppliedName{}; // value from registry. Required for WinMM classic naming. This was at parent level efore, but loopMIDI and similar register per-interface
+
+ ULONG PinNumber{ 0 };
+ std::wstring PinName;
+
+ MidiFlow PinDataFlow;
+ MidiFlow DataFlowFromUserPerspective;
+
+ uint8_t GroupIndex{ 0 };
+ uint8_t PortIndexWithinThisFilterAndDirection{ 0 }; // not always the same as the group index. Example: MOTU Express 128 with separate filter for each in/out pair
+};
+
+struct KsAggregateEndpointDefinition2
+{
+ std::wstring ParentDeviceInstanceId{};
+
+ std::wstring EndpointDeviceId{};
+
+ std::wstring EndpointName{};
+ std::wstring EndpointDeviceInstanceId{};
+
+ std::vector> MidiPins{ };
+
+ WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{ };
+
+ uint32_t EndpointIndexForThisParentDevice{ 0 };
+
+
+ int8_t CurrentHighestMidiSourceGroupIndex{ -1 };
+ int8_t CurrentHighestMidiDestinationGroupIndex{ -1 };
+
+};
+
+
+class KsAggregateParentDeviceDefinition2
+{
+public:
+ std::wstring DeviceName{};
+ std::wstring DeviceInstanceId{};
+
+ uint32_t IndexOfDevicesWithThisSameName{ 0 }; // for when there are multiple of the same device
+
+
+ uint16_t VID{ 0 }; // USB-only
+ uint16_t PID{ 0 }; // USB-only
+ std::wstring SerialNumber{};
+
+ std::wstring ManufacturerName{};
+};
+
+
+class CMidi2KSAggregateMidiEndpointManager2 :
+ public Microsoft::WRL::RuntimeClass<
+ Microsoft::WRL::RuntimeClassFlags,
+ IMidiEndpointManager>
+{
+public:
+
+ STDMETHOD(Initialize(_In_ IMidiDeviceManager*, _In_ IMidiEndpointProtocolManager*));
+ STDMETHOD(Shutdown)();
+
+ // returns the endpointDeviceInterfaceId for a matching endpoint found in m_availableEndpointDefinitions
+ winrt::hstring FindMatchingInstantiatedEndpoint(_In_ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria);
+
+private:
+ DWORD m_individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS };
+
+ wil::com_ptr_nothrow m_midiDeviceManager;
+ wil::com_ptr_nothrow m_midiProtocolManager;
+
+ HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation);
+ HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate);
+ HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate);
+
+ HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable);
+ HRESULT OnDeviceWatcherStopped(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable);
+
+
+ // key is parent device instance id. These don't get "activated" so we keep a single list
+ std::map> m_allParentDeviceDefinitions;
+ wil::critical_section m_allParentDeviceDefinitionsLock;
+
+ // these are all endpoints which have not yet been activated
+ std::vector> m_pendingEndpointDefinitions;
+ wil::critical_section m_pendingEndpointDefinitionsLock;
+
+ // key is parent device instance id
+ std::map> m_activatedEndpointDefinitions;
+ wil::critical_section m_activatedEndpointDefinitionsLock;
+
+
+
+ bool ActiveKSAEndpointForDeviceExists(
+ _In_ std::wstring deviceInstanceId);
+
+ HRESULT ParseParentIdIntoVidPidSerial(
+ _In_ std::wstring systemDevicesParentValue,
+ _In_ std::shared_ptr parentDevice);
+
+
+ HRESULT FindActivatedEndpointDefinitionForFilterDevice(
+ _In_ std::wstring filterDeviceId,
+ _Inout_ std::shared_ptr&);
+
+ HRESULT FindAllActivatedEndpointDefinitionsForParentDevice(
+ _In_ std::wstring parentDeviceInstanceId,
+ _Inout_ std::vector>& endpointDefinitions);
+
+ HRESULT FindAllPendingEndpointDefinitionsForParentDevice(
+ _In_ std::wstring parentDeviceInstanceId,
+ _Inout_ std::vector>& endpointDefinitions);
+
+
+ HRESULT FindPendingEndpointDefinitionForParentDevice(
+ _In_ std::wstring parentDeviceInstanceId,
+ _Inout_ std::shared_ptr&);
+
+ HRESULT FindExistingParentDeviceDefinitionForEndpoint(
+ _In_ std::shared_ptr endpointDefinition,
+ _Inout_ std::shared_ptr& parentDeviceDefinition);
+
+ HRESULT FindOrCreateParentDeviceDefinitionForFilterDevice(
+ _In_ DeviceInformation filterDevice,
+ _Inout_ std::shared_ptr& parentDeviceDefinition);
+
+ HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice(
+ _In_ DeviceInformation,
+ _Inout_ std::shared_ptr&);
+
+
+ HRESULT FindCurrentMaxEndpointIndexForParentDevice(
+ _In_ std::shared_ptr parentDeviceDefinition,
+ _In_ uint32_t& currentMaxIndex);
+
+
+ HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName);
+ HRESULT GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow);
+
+ HRESULT GetMidi1FilterPins(
+ _In_ DeviceInformation,
+ _In_ std::vector>&,
+ _Inout_ uint8_t& countMidiSourcePinsAdded,
+ _Inout_ uint8_t& countMidiDestinationPinsAdded);
+
+ HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name);
+
+
+ HRESULT IncrementAndGetNextGroupIndex(
+ _In_ std::shared_ptr definition,
+ _In_ MidiFlow dataFlowFromUserPerspective,
+ _In_ uint8_t& groupIndex);
+
+ HRESULT UpdateNewPinDefinitions(
+ _In_ std::wstring filterDeviceid,
+ _In_ std::shared_ptr endpointDefinition);
+
+ HRESULT BuildPinsAndGroupTerminalBlocksPropertyData(
+ _In_ std::shared_ptr masterEndpointDefinition,
+ _In_ std::vector& pinMapPropertyData,
+ _In_ std::vector& groupTerminalBlocks);
+
+ HRESULT UpdateNameTableWithCustomProperties(
+ _In_ std::shared_ptr masterEndpointDefinition,
+ _In_ std::shared_ptr customProperties);
+
+
+ // these two functions actually update the software devices in Windows
+
+ HRESULT DeviceCreateMidiUmpEndpoint(
+ _In_ std::shared_ptr masterEndpointDefinition);
+
+ HRESULT DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(
+ _In_ std::shared_ptr masterEndpointDefinition);
+
+
+ wil::unique_event_nothrow m_endpointCreationThreadWakeup;
+ std::jthread m_endpointCreationThread;
+ void EndpointCreationThreadWorker(_In_ std::stop_token token);
+
+
+ DeviceWatcher m_watcher{0};
+ winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded;
+ winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Removed_revoker m_DeviceRemoved;
+ winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Updated_revoker m_DeviceUpdated;
+ winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Stopped_revoker m_DeviceStopped;
+ winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::EnumerationCompleted_revoker m_DeviceEnumerationCompleted;
+ wil::unique_event m_EnumerationCompleted{wil::EventOptions::None};
+
+
+};
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp
index a903ec2c..4862ac3d 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp
@@ -30,6 +30,7 @@ CMidi2KSAggregateMidiInProxy::Initialize(
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
+ TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD),
TraceLoggingUInt32(pinId, "Pin id"),
TraceLoggingUInt8(groupIndex, "Group index")
@@ -117,7 +118,7 @@ CMidi2KSAggregateMidiInProxy::Callback(
RETURN_HR_IF_NULL(E_POINTER, m_callback);
RETURN_HR_IF_NULL(E_POINTER, m_bs2UmpTransform);
-#ifndef _DEBUG
+#ifdef _DEBUG
TraceLoggingWrite(
MidiKSAggregateTransportTelemetryProvider::Provider(),
MIDI_TRACE_EVENT_VERBOSE,
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp
index f3339531..3f9ada05 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp
@@ -44,14 +44,29 @@ CMidi2KSAggregateTransport::Activate(
TraceLoggingWideString(L"IMidiEndpointManager", MIDI_TRACE_EVENT_INTERFACE_FIELD)
);
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ // check to see if this is the first time we're creating the endpoint manager. If so, create it.
+ if (TransportState::Current().GetEndpointManager2() == nullptr)
+ {
+ TransportState::Current().ConstructEndpointManager();
+ }
- // check to see if this is the first time we're creating the endpoint manager. If so, create it.
- if (TransportState::Current().GetEndpointManager() == nullptr)
+ RETURN_IF_FAILED(TransportState::Current().GetEndpointManager2()->QueryInterface(iid, activatedInterface));
+ }
+ else
{
- TransportState::Current().ConstructEndpointManager();
+ // check to see if this is the first time we're creating the endpoint manager. If so, create it.
+ if (TransportState::Current().GetEndpointManager() == nullptr)
+ {
+ TransportState::Current().ConstructEndpointManager();
+ }
+
+ RETURN_IF_FAILED(TransportState::Current().GetEndpointManager()->QueryInterface(iid, activatedInterface));
}
- RETURN_IF_FAILED(TransportState::Current().GetEndpointManager()->QueryInterface(iid, activatedInterface));
+
+
}
else if (__uuidof(IMidiTransportConfigurationManager) == iid)
{
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc
index 232b8b0c..70183e09 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc
@@ -53,8 +53,8 @@ END
//
VS_VERSION_INFO VERSIONINFO
- FILEVERSION 1,0,15,0
- PRODUCTVERSION 1,0,15,0
+ FILEVERSION 1,0,16,0
+ PRODUCTVERSION 1,0,16,0
FILEFLAGSMASK VS_FFI_FILEFLAGSMASK
#ifdef _DEBUG
FILEFLAGS VS_FF_DEBUG
@@ -71,12 +71,12 @@ BEGIN
BEGIN
VALUE "CompanyName", "Microsoft Corporation"
VALUE "FileDescription", "Windows MIDI Services KSA Transport Plugin"
- VALUE "FileVersion", "1.0.15.0"
+ VALUE "FileVersion", "1.0.16.0"
VALUE "LegalCopyright", "Copyright (c) Microsoft Corporation. All rights reserved."
VALUE "InternalName", "Midi2.KSAggregateTransport.dll"
VALUE "OriginalFilename", "Midi2.KSAggregateTransport.dll"
VALUE "ProductName", "Microsoft Windows MIDI Services"
- VALUE "ProductVersion", "1.0.15.0"
+ VALUE "ProductVersion", "1.0.16.0"
VALUE "OLESelfRegister", ""
END
END
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj
index 822f30a5..145bc00c 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj
@@ -359,6 +359,7 @@
+
@@ -369,7 +370,7 @@
-
+
@@ -380,6 +381,7 @@
+
@@ -392,7 +394,7 @@
-
+
diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters
index 69c3f79d..8dcaaf34 100644
--- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters
+++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters
@@ -39,7 +39,7 @@
Source Files
-
+
Source Files
@@ -48,6 +48,9 @@
Source Files
+
+ Source Files
+
@@ -94,7 +97,7 @@
Header Files
-
+
Header Files
@@ -103,6 +106,9 @@
Header Files
+
+ Header Files
+
diff --git a/src/api/Transport/KSAggregateTransport/TransportState.cpp b/src/api/Transport/KSAggregateTransport/TransportState.cpp
index da0e8411..2e723832 100644
--- a/src/api/Transport/KSAggregateTransport/TransportState.cpp
+++ b/src/api/Transport/KSAggregateTransport/TransportState.cpp
@@ -25,12 +25,22 @@ TransportState& TransportState::Current()
HRESULT
TransportState::ConstructEndpointManager()
{
- RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager));
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager2));
+ }
+ else
+ {
+ RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager));
+ }
+
return S_OK;
}
+
+
HRESULT
TransportState::ConstructConfigurationManager()
{
diff --git a/src/api/Transport/KSAggregateTransport/TransportState.h b/src/api/Transport/KSAggregateTransport/TransportState.h
index 177f10f3..3b7928de 100644
--- a/src/api/Transport/KSAggregateTransport/TransportState.h
+++ b/src/api/Transport/KSAggregateTransport/TransportState.h
@@ -23,7 +23,27 @@ class TransportState
wil::com_ptr GetEndpointManager()
{
- return m_endpointManager;
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ return nullptr;
+ }
+ else
+ {
+ return m_endpointManager;
+ }
+ }
+
+ // for Feature_Servicing_MIDI2VirtualPortDriversFix
+ wil::com_ptr GetEndpointManager2()
+ {
+ if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled())
+ {
+ return m_endpointManager2;
+ }
+ else
+ {
+ return nullptr;
+ }
}
wil::com_ptr GetConfigurationManager()
@@ -51,6 +71,10 @@ class TransportState
wil::com_ptr m_endpointManager;
+
+ // for Feature_Servicing_MIDI2VirtualPortDriversFix
+ wil::com_ptr m_endpointManager2 { nullptr };
+
wil::com_ptr m_configurationManager;
};
\ No newline at end of file
diff --git a/src/api/Transport/KSAggregateTransport/pch.h b/src/api/Transport/KSAggregateTransport/pch.h
index 32e03421..5362975c 100644
--- a/src/api/Transport/KSAggregateTransport/pch.h
+++ b/src/api/Transport/KSAggregateTransport/pch.h
@@ -36,6 +36,8 @@
#include
#include
+#include
+#include
#define _ATL_APARTMENT_THREADED
@@ -111,8 +113,10 @@ namespace internal = ::WindowsMidiServicesInternal;
#include "Midi2UMP2BSTransform.h"
#include "Midi2UMP2BSTransform_i.c"
+#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h"
class CMidi2KSAggregateMidiEndpointManager;
+class CMidi2KSAggregateMidiEndpointManager2;
class CMidi2KSAggregateMidiInProxy;
class CMidi2KSAggregateMidiOutProxy;
class CMidi2KSAggregateMidiConfigurationManager;
@@ -125,6 +129,7 @@ class TransportState;
#include "Midi2.KSAggregateMidi.h"
#include "Midi2.KSAggregateMidiBidi.h"
#include "Midi2.KSAggregateMidiEndpointManager.h"
+#include "Midi2.KSAggregateMidiEndpointManager2.h"
#include "Midi2.KSAggregateMidiConfigurationManager.h"
#include "Midi2.KSAggregateMidiPluginMetadataProvider.h"