From 3ccd21dd6c4959f29e3f55339cb379134a1cc063 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sat, 31 Jan 2026 11:56:51 -0500 Subject: [PATCH 01/18] Start working on issue 835 --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 32 +++++++++++++++++++ .../Midi2.KSAggregateMidiEndpointManager.h | 4 +++ 2 files changed, 36 insertions(+) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index b2a19b9fb..56e5abbc2 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -737,6 +737,38 @@ ParseParentIdIntoVidPidSerial( + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded( + DeviceWatcher watcher, + DeviceInformation deviceInterface +) +{ + UNREFERENCED_PARAMETER(watcher); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(deviceInterface.Id().c_str(), "added interface") + ); + + + // Check to see if the interface has pins we want + + // if this has MIDI 1 pins, check to see if we already have an entry for the device + // create the device if we don't already have it + // get the USB vid/pid from the parent device, if it is a USB device + + + + return S_OK; +} + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded( diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index ee365894a..c82c2a877 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -77,6 +77,10 @@ class CMidi2KSAggregateMidiEndpointManager : HRESULT OnDeviceStopped(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); + // new interface-based approach + HRESULT OnDeviceInterfaceAdded(_In_ DeviceWatcher watcher, _In_ DeviceInformation deviceInterface); + + wil::com_ptr_nothrow m_midiDeviceManager; wil::com_ptr_nothrow m_midiProtocolManager; From 208f42882c9faf595c58585804a1c8418a33f4f9 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Fri, 6 Feb 2026 23:05:57 -0500 Subject: [PATCH 02/18] Working on KSA bug --- build/replace_just_ksa_x64.bat | 21 + .../Midi2.KSAggregateMidi.cpp | 199 +++++--- .../Midi2.KSAggregateMidiEndpointManager.cpp | 440 ++++++++++++++++-- .../Midi2.KSAggregateMidiEndpointManager.h | 10 +- src/api/Transport/KSAggregateTransport/pch.h | 8 + 5 files changed, 590 insertions(+), 88 deletions(-) create mode 100644 build/replace_just_ksa_x64.bat diff --git a/build/replace_just_ksa_x64.bat b/build/replace_just_ksa_x64.bat new file mode 100644 index 000000000..e1819858b --- /dev/null +++ b/build/replace_just_ksa_x64.bat @@ -0,0 +1,21 @@ +@echo off +echo This must be run as administrator. + +set servicepath="%ProgramFiles%\Windows MIDI Services\Service" +set apipath="%ProgramFiles%\Windows MIDI Services\API" +set dmppath="%ProgramFiles%\Windows MIDI Services\" +set buildoutput="%midi_repo_root%src\api\VSFiles\x64\Release" + + + +echo Stopping midisrv +net stop midisrv + +echo Copying KSA Transport +mkdir %servicepath% +copy /Y %buildoutput%\Midi2.KSAggregateTransport.dll %servicepath% +regsvr32 %servicepath%\Midi2.KSAggregateTransport.dll + +net start midisrv + +pause \ No newline at end of file diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp index cffab8444..372982727 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp @@ -205,40 +205,84 @@ CMidi2KSAggregateMidi::Initialize( wil::com_ptr_nothrow proxy; RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&proxy)); - auto initResult = - proxy->Initialize( - endpointDeviceInterfaceId, - handleDupe.get(), - pinMapEntry->PinId, - requestedBufferSize, - mmCssTaskId, - m_callback, - context, - pinMapEntry->GroupIndex - ); - - if (SUCCEEDED(initResult)) + // needed for internal consumption. Gary to replace this with feature enablement check + // defined in pch.h + if (NewMidiFeatureUpdateKsa2603Enabled()) { - m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + auto initResult = + proxy->Initialize( + filterInterfaceId.c_str(), + handleDupe.get(), + pinMapEntry->PinId, + requestedBufferSize, + mmCssTaskId, + m_callback, + context, + pinMapEntry->GroupIndex + ); + + if (SUCCEEDED(initResult)) + { + m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), + TraceLoggingUInt32(requestedBufferSize, "buffer size"), + TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), + TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), + TraceLoggingWideString(filterInterfaceId.c_str(), "filter") + ); + + RETURN_IF_FAILED(initResult); + } } else { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_ERROR, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), - TraceLoggingUInt32(requestedBufferSize, "buffer size"), - TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), - TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), - TraceLoggingWideString(filterInterfaceId.c_str(), "filter") - ); - - RETURN_IF_FAILED(initResult); + auto initResult = + proxy->Initialize( + endpointDeviceInterfaceId, + handleDupe.get(), + pinMapEntry->PinId, + requestedBufferSize, + mmCssTaskId, + m_callback, + context, + pinMapEntry->GroupIndex + ); + + if (SUCCEEDED(initResult)) + { + m_midiInDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Unable to initialize Midi Input proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), + TraceLoggingUInt32(requestedBufferSize, "buffer size"), + TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), + TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), + TraceLoggingWideString(filterInterfaceId.c_str(), "filter") + ); + + RETURN_IF_FAILED(initResult); + } } + + } else if (pinMapEntry->PinDataFlow == MidiFlow::MidiFlowIn) { @@ -247,38 +291,79 @@ CMidi2KSAggregateMidi::Initialize( wil::com_ptr_nothrow proxy; RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&proxy)); - auto initResult = - proxy->Initialize( - endpointDeviceInterfaceId, - handleDupe.get(), - pinMapEntry->PinId, - requestedBufferSize, - mmCssTaskId, - context, - pinMapEntry->GroupIndex - ); - - if (SUCCEEDED(initResult)) + // needed for internal consumption. Gary to replace this with feature enablement check + // defined in pch.h + if (NewMidiFeatureUpdateKsa2603Enabled()) { - m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + auto initResult = + proxy->Initialize( + filterInterfaceId.c_str(), + handleDupe.get(), + pinMapEntry->PinId, + requestedBufferSize, + mmCssTaskId, + context, + pinMapEntry->GroupIndex + ); + + if (SUCCEEDED(initResult)) + { + m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), + TraceLoggingUInt32(requestedBufferSize, "buffer size"), + TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), + TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), + TraceLoggingWideString(filterInterfaceId.c_str(), "filter") + ); + + RETURN_IF_FAILED(initResult); + } } else { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_ERROR, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), - TraceLoggingUInt32(requestedBufferSize, "buffer size"), - TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), - TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), - TraceLoggingWideString(filterInterfaceId.c_str(), "filter") - ); - - RETURN_IF_FAILED(initResult); + auto initResult = + proxy->Initialize( + endpointDeviceInterfaceId, + handleDupe.get(), + pinMapEntry->PinId, + requestedBufferSize, + mmCssTaskId, + context, + pinMapEntry->GroupIndex + ); + + if (SUCCEEDED(initResult)) + { + m_midiOutDeviceGroupMap.insert_or_assign(pinMapEntry->GroupIndex, std::move(proxy)); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Unable to initialize Midi Output proxy", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), + TraceLoggingUInt32(requestedBufferSize, "buffer size"), + TraceLoggingUInt32(pinMapEntry->PinId, "pin id"), + TraceLoggingUInt8(pinMapEntry->GroupIndex, "group"), + TraceLoggingWideString(filterInterfaceId.c_str(), "filter") + ); + + RETURN_IF_FAILED(initResult); + } } } diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 56e5abbc2..5ba3d3daf 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -44,31 +44,76 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( RETURN_IF_FAILED(midiDeviceManager->QueryInterface(__uuidof(IMidiDeviceManager), (void**)&m_midiDeviceManager)); RETURN_IF_FAILED(midiEndpointProtocolManager->QueryInterface(__uuidof(IMidiEndpointProtocolManager), (void**)&m_midiProtocolManager)); - winrt::hstring parentDeviceSelector( - L"System.Devices.ClassGuid:=\"{4d36e96c-e325-11ce-bfc1-08002be10318}\" AND " \ - L"System.Devices.Present:=System.StructuredQueryType.Boolean#True"); + // needed for internal consumption. Gary to replace this with feature enablement check + // defined in pch.h + if (NewMidiFeatureUpdateKsa2603Enabled()) + { + DWORD individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; + if (SUCCEEDED(wil::reg::get_value_dword_nothrow(HKEY_LOCAL_MACHINE, MIDI_ROOT_REG_KEY, KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE, &individualInterfaceEnumTimeoutMS))) + { + individualInterfaceEnumTimeoutMS = max(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE); + individualInterfaceEnumTimeoutMS = min(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE); + + m_individualInterfaceEnumTimeoutMS = individualInterfaceEnumTimeoutMS; + } + else + { + m_individualInterfaceEnumTimeoutMS = DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS; + } + + + // the ksa2603 fix enumerates device interfaces instead of parent devices + + winrt::hstring deviceInterfaceSelector( + L"System.Devices.InterfaceClassGuid:=\"{6994AD04-93EF-11D0-A3CC-00A0C9223196}\" AND " \ + L"System.Devices.InterfaceEnabled: = System.StructuredQueryType.Boolean#True"); + + auto additionalProps = winrt::single_threaded_vector(); + + m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector); + + auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded); + auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceRemoved); + auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceUpdated); + + auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); + auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); + + m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); + m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); + m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); + m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); + m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); + } + else + { + winrt::hstring parentDeviceSelector( + L"System.Devices.ClassGuid:=\"{4d36e96c-e325-11ce-bfc1-08002be10318}\" AND " \ + L"System.Devices.Present:=System.StructuredQueryType.Boolean#True"); + + // :=System.StructuredQueryType.Boolean#True - // :=System.StructuredQueryType.Boolean#True + auto additionalProps = winrt::single_threaded_vector(); - auto additionalProps = winrt::single_threaded_vector(); + additionalProps.Append(L"System.Devices.DeviceManufacturer"); + additionalProps.Append(L"System.Devices.Manufacturer"); + additionalProps.Append(L"System.Devices.Parent"); - additionalProps.Append(L"System.Devices.DeviceManufacturer"); - additionalProps.Append(L"System.Devices.Manufacturer"); - additionalProps.Append(L"System.Devices.Parent"); + m_watcher = DeviceInformation::CreateWatcher(parentDeviceSelector, additionalProps, DeviceInformationKind::Device); - m_watcher = DeviceInformation::CreateWatcher(parentDeviceSelector, additionalProps, DeviceInformationKind::Device); + auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded); + auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved); + auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceUpdated); + auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); + auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); - auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded); - auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved); - auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceUpdated); - auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); - auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); + m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); + m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); + m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); + m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); + m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); + } - m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); - m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); - m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); - m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); - m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); m_watcher.Start(); @@ -736,13 +781,317 @@ ParseParentIdIntoVidPidSerial( } +// Assuming the update/add timeout is 250ms : UPDATE_TIMEOUT. May need to make this a registry setting. + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( + DeviceWatcher watcher, + DeviceInformation filterDevice +) +{ + UNREFERENCED_PARAMETER(watcher); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") + ); + + + //m_individualInterfaceEnumTimeoutMS + + // Flow for interface ADDED + // - Check to see if this interface has a MIDI 1 pin. If not, bail. + // - Find the parent device (need to go an extra step up. Check KS code to see if similar now) + // - Check to see if we already have the parent device in the pending list. If not, create an entry in the pending list + // - Add the interface to the pending parent device + // - Reset the UPDATE_TIMEOUT timeout for this device (will need timeouts per-device, so maybe worker threads) + // - If no other changes for this device within that UPDATE_TIMEOUT, then + // - Build pin map + // - Build GTBs + // - Build name table + // - Create the device + + + std::wstring transportCode(TRANSPORT_CODE); + std::wstring driverSuppliedName{}; + + + + + // Wrapper opens the handle internally. + KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); + RETURN_IF_FAILED(deviceHandleWrapper.Open()); + + // Using lamba function to prevent handle from dissapearing when being used. + LOG_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + })); + + // enumerate all the pins for this filter + ULONG cPins{ 0 }; + + HRESULT hr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins)); + }); + + RETURN_IF_FAILED(hr); + + // process the pins for this filter. Not all will be MIDI pins + for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++) + { + // bool isMidi1Pin{ false }; + + // TODO + std::wstring customPortName{}; + + // Check the communication capabilities of the pin so we can fail fast + KSPIN_COMMUNICATION communication = (KSPIN_COMMUNICATION)0; + + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return PinPropertySimple(h, pinIndex, KSPROPSETID_Pin, KSPROPERTY_PIN_COMMUNICATION, &communication, sizeof(KSPIN_COMMUNICATION)); + })); + + // The external connector pin representing the physical connection + // has KSPIN_COMMUNICATION_NONE. We can only create on software IO pins which + // have a valid communication. Skip connector pins. + if (communication == KSPIN_COMMUNICATION_NONE) + { + continue; + } + + // Duplicate the handle to safely pass it to another component or store it. + wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle()); + RETURN_IF_NULL_ALLOC(handleDupe); + + // we try to open UMP only so we understand the device + KsHandleWrapper m_PinHandleWrapperUmp(filterDevice.Id().c_str(), pinIndex, MidiTransport_CyclicUMP, handleDupe.get()); + if (SUCCEEDED(m_PinHandleWrapperUmp.Open())) + { + // this is a UMP pin. The KS transport will handle it, so we skip it here. + // In the future, we may want to bail on the first UMP pin we find. + + continue; + } + + // try to open as a MIDI 1 bytestream pin + KsHandleWrapper m_PinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); + if (SUCCEEDED(m_PinHandleWrapperMidi1.Open())) + { + // this is a MIDI 1.0 byte format pin, so let's process it + KsAggregateEndpointMidiPinDefinition pinDefinition{ }; + + //pinDefinition.KSDriverSuppliedName = driverSuppliedName; + pinDefinition.PinNumber = pinIndex; + pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; + pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; + + // find the name of the pin (supplied by iJack, and if not available, generated by the stack) + std::wstring pinName{ }; + + HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinName(h, pinIndex, pinName); + }); + + if (SUCCEEDED(pinNameHr)) + { + pinDefinition.PinName = pinName; + } + + + + // get the data flow so we know if this is a MIDI Input or a MIDI Output + KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; + + HRESULT dataFlowHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinDataFlow(h, pinIndex, dataFlow); + }); + + if (SUCCEEDED(dataFlowHr)) + { + if (dataFlow == KSPIN_DATAFLOW_IN) + { + // MIDI Out (input to device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite + + pinDefinition.GroupIndex = static_cast(midiOutputGroupIndexForDevice); + + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); + + + + midiOutputPinIndexForThisFilter++; + midiOutputGroupIndexForDevice++; + } + else if (dataFlow == KSPIN_DATAFLOW_OUT) + { + // MIDI In (output from device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite + + pinDefinition.GroupIndex = static_cast(midiInputGroupIndexForDevice); + + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); + + midiInputPinIndexForThisFilter++; + midiInputGroupIndexForDevice++; + } + + // This is where we build the proposed names + // ================================================= + + std::wstring customName = L""; + + endpointDefinition.EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition.GroupIndex, + pinDefinition.DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition.FilterName, + pinDefinition.PinName, + pinDefinition.PortIndexWithinThisFilterAndDirection + ); + + endpointDefinition.MidiPins.push_back(pinDefinition); + } + else + { + // this is a failure condition. Move on to next pin + RETURN_IF_FAILED(dataFlowHr); + } + } + } + + + + + + + + + + + + //auto additionalProperties = winrt::single_threaded_vector(); + auto properties = parentDevice.Properties(); + + KsAggregateEndpointDefinition endpointDefinition{ }; + + auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", parentDevice, L""); + RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); + + + auto systemDevicesParent = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Parent", parentDevice, L""); + + endpointDefinition.ParentDeviceName = parentDevice.Name(); + endpointDefinition.EndpointName = parentDevice.Name(); + endpointDefinition.ParentDeviceInstanceId = parentDevice.Id(); + + if (!systemDevicesParent.empty()) + { + LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(systemDevicesParent, endpointDefinition)); + } + + // we set this if we find any compatible MIDI 1.0 byte format pins + bool isCompatibleMidi1Device{ false }; + + // enumerate all KS_CATEGORY_AUDIO filters for this parent media device + winrt::hstring filterDeviceSelector( + L"System.Devices.InterfaceClassGuid:=\"" + winrt::hstring(KS_CATEGORY_AUDIO_GUID) + "\""\ + L" AND System.Devices.InterfaceEnabled:= System.StructuredQueryType.Boolean#True"\ + L" AND System.Devices.DeviceInstanceId:= \"" + deviceInstanceId + L"\""); + + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enumerating Filters", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDevice.Id().c_str(), "parent device id") + ); + + auto filterDevices = DeviceInformation::FindAllAsync(filterDeviceSelector).get(); + + //ULONG midiInputGroupIndexForDevice{ 0 }; + //ULONG midiOutputGroupIndexForDevice{ 0 }; + + //ULONG midiInputPinIndexForThisFilter{ 0 }; + //ULONG midiOutputPinIndexForThisFilter{ 0 }; + + + + // now create the device + + if (isCompatibleMidi1Device && endpointDefinition.MidiPins.size() > 0) + { + // only some vendor drivers provide an actual manufacturer + // and all the in-box drivers just provide the Generic USB Audio string + // TODO: Is "Generic USB Audio" a string that is localized? If so, this + // code will not have the intended effect outside of en-US + auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); + if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") + { + endpointDefinition.ManufacturerName = manufacturer; + } + + // default hash is the device id. We don't have an iSerial here. + std::hash hasher; + std::wstring hash; + hash = std::to_wstring(hasher(endpointDefinition.ParentDeviceInstanceId)); + + endpointDefinition.EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Creating aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // We've enumerated all the pins on the device. Now create the aggregated UMP endpoint + RETURN_IF_FAILED(CreateMidiUmpEndpoint(endpointDefinition)); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No compatible MIDI pins. This is normal in most cases.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + + + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Filter Enumeration Complete", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDevice.Id().c_str(), "parent device id") + ); + return S_OK; +} _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded( +CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceRemoved( DeviceWatcher watcher, - DeviceInformation deviceInterface + DeviceInformationUpdate deviceInterfaceUpdate ) { UNREFERENCED_PARAMETER(watcher); @@ -753,15 +1102,48 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(deviceInterface.Id().c_str(), "added interface") + TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "added interface") ); + // Flow for interface REMOVED + // - Check for the interface on the existing device + // - Update pin map to remove all entries with that interface + // - If this is the last interface, then remove the device. + // - If not the last interface: + // - Reset the UPDATE_TIMEOUT timeout for this device. If no other removals come through for this device during the timeout: + // - Rebuild pin map, maintaining existing numbers where possible + // - Rebuild GTBs, maintaining existing numbers where possible + // - Recalculate name table + // - call MidiDeviceManager::UpdateEndpointProperties. That will also recalculate MIDI 1 ports + // + + + + return S_OK; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceUpdated( + DeviceWatcher watcher, + DeviceInformationUpdate deviceInterfaceUpdate +) +{ + UNREFERENCED_PARAMETER(watcher); - // Check to see if the interface has pins we want + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "updated interface") + ); - // if this has MIDI 1 pins, check to see if we already have an entry for the device - // create the device if we don't already have it - // get the USB vid/pid from the parent device, if it is a USB device + // Flow for interface UPDATED + // - Check for any name changes + // - If any relevant changes recalculate GTBs and Name table as above and update properties + // @@ -769,6 +1151,9 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded( } + + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded( @@ -1103,9 +1488,6 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded( } - - - _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved(DeviceWatcher watcher, DeviceInformationUpdate device) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index c82c2a877..fbc639fe1 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -14,6 +14,10 @@ using namespace winrt::Windows::Devices::Enumeration; +#define DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS 250 +#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE 50 +#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE 2500 +#define KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE L"KsaInterfaceEnumTimeoutMS" struct KsAggregateEndpointMidiPinDefinition { @@ -78,7 +82,9 @@ class CMidi2KSAggregateMidiEndpointManager : HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); // new interface-based approach - HRESULT OnDeviceInterfaceAdded(_In_ DeviceWatcher watcher, _In_ DeviceInformation deviceInterface); + HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); + HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); + HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); wil::com_ptr_nothrow m_midiDeviceManager; @@ -97,5 +103,5 @@ class CMidi2KSAggregateMidiEndpointManager : HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); - + DWORD m_individualInterfaceEnumTimeoutMS { DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; }; diff --git a/src/api/Transport/KSAggregateTransport/pch.h b/src/api/Transport/KSAggregateTransport/pch.h index 32e034212..b0474480d 100644 --- a/src/api/Transport/KSAggregateTransport/pch.h +++ b/src/api/Transport/KSAggregateTransport/pch.h @@ -112,6 +112,14 @@ namespace internal = ::WindowsMidiServicesInternal; #include "Midi2UMP2BSTransform_i.c" +// this gets replaced with the internal CFR check when pulled into Windows repo +inline bool NewMidiFeatureUpdateKsa2603Enabled() { return true; } + + + + + + class CMidi2KSAggregateMidiEndpointManager; class CMidi2KSAggregateMidiInProxy; class CMidi2KSAggregateMidiOutProxy; From b100127d2fd42d1558c555cfb99e70fcc6c3e1c6 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sat, 7 Feb 2026 20:50:45 -0500 Subject: [PATCH 03/18] Working on KSA bug 835 --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 694 +++++++++++++----- .../Midi2.KSAggregateMidiEndpointManager.h | 19 + src/api/Transport/KSAggregateTransport/pch.h | 2 + 3 files changed, 531 insertions(+), 184 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 5ba3d3daf..43a260813 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -72,9 +72,9 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector); - auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceAdded); - auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceRemoved); - auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceInterfaceUpdated); + auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded); + auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceRemoved); + auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceUpdated); auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); @@ -84,6 +84,11 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); + + // worker thread to handle endpoint creation, since we're enumerating interfaces now and need to aggregate them + m_endpointCreationThreadWakeup.create(wil::EventOptions::ManualReset); + std::jthread endpointCreationWorkerThread(std::bind_front(&CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker, this)); + m_endpointCreationThread = std::move(endpointCreationWorkerThread); } else { @@ -656,7 +661,18 @@ CMidi2KSAggregateMidiEndpointManager::GetKSDriverSuppliedName(HANDLE hInstantiat &countBytesReturned ); - RETURN_IF_FAILED(hrComponent); + if (NewMidiFeatureUpdateKsa2603Enabled()) + { + // changed to not log the failure here. Failures are expected for many devices, and it's adding noise to error logs + if (FAILED(hrComponent)) + { + return hrComponent; + } + } + else + { + RETURN_IF_FAILED(hrComponent); + } componentId.Name; // this is the GUID which points to the registry location with the driver-supplied name @@ -684,6 +700,7 @@ CMidi2KSAggregateMidiEndpointManager::GetKSDriverSuppliedName(HANDLE hInstantiat #define KS_CATEGORY_AUDIO_GUID L"{6994AD04-93EF-11D0-A3CC-00A0C9223196}" + HRESULT ParseParentIdIntoVidPidSerial( _In_ winrt::hstring systemDevicesParentValue, @@ -781,16 +798,102 @@ ParseParentIdIntoVidPidSerial( } -// Assuming the update/add timeout is 250ms : UPDATE_TIMEOUT. May need to make this a registry setting. - _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( - DeviceWatcher watcher, - DeviceInformation filterDevice +CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFilterDevice( + DeviceInformation filterDevice, + std::shared_ptr& endpointDefinition ) { - UNREFERENCED_PARAMETER(watcher); + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // we require that the System.Devices.DeviceInstanceId property was requested for the passed-in filter device + auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); + + auto additionalProperties = winrt::single_threaded_vector(); + additionalProperties.Append(L"System.Devices.DeviceManufacturer"); + additionalProperties.Append(L"System.Devices.Manufacturer"); + additionalProperties.Append(L"System.Devices.Parent"); + + auto parentDevice = DeviceInformation::CreateFromIdAsync(deviceInstanceId, + additionalProperties, winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get(); + + // See if we already have a pending master endpoint definition for this parent device + + auto lock = m_pendingEndpointDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing + + auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(deviceInstanceId.c_str()); + auto it = std::find_if( + m_pendingEndpointDefinitions.begin(), + m_pendingEndpointDefinitions.end(), + [&parentInstanceIdToFind](const std::shared_ptr def){return def->ParentDeviceInstanceId == parentInstanceIdToFind; }); + + if (it != m_pendingEndpointDefinitions.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found existing aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentInstanceIdToFind.c_str(), "parent") + ); + + endpointDefinition = *it; + return S_OK; + } + + // We don't have one, create one and add, and get all the parent device information for it + // we still have the map locked, so keep this code fast + auto newEndpointDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_OUTOFMEMORY, newEndpointDefinition); + +// auto systemDevicesParent = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Parent", parentDevice, L""); + + newEndpointDefinition->ParentDeviceName = parentDevice.Name(); + newEndpointDefinition->EndpointName = parentDevice.Name(); + newEndpointDefinition->ParentDeviceInstanceId = parentDevice.Id(); + + LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newEndpointDefinition->ParentDeviceInstanceId.c_str(), *newEndpointDefinition)); + + // TEMP. Remove before publishing + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Creating new aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(newEndpointDefinition->ParentDeviceInstanceId.c_str(), "parent"), + TraceLoggingUInt16(newEndpointDefinition->VID, "VID"), + TraceLoggingUInt16(newEndpointDefinition->PID, "PID") + ); + + // only some vendor drivers provide an actual manufacturer + // and all the in-box drivers just provide the Generic USB Audio string + // TODO: Is "Generic USB Audio" a string that is localized? If so, this + // code will not have the intended effect outside of en-US + auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); + if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") + { + newEndpointDefinition->ManufacturerName = manufacturer; + } + + // default hash is the device id. + std::hash hasher; + std::wstring hash; + hash = std::to_wstring(hasher(newEndpointDefinition->ParentDeviceInstanceId)); + + newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -798,55 +901,243 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") + TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) ); + m_pendingEndpointDefinitions.push_back(newEndpointDefinition); + endpointDefinition = newEndpointDefinition; - //m_individualInterfaceEnumTimeoutMS - // Flow for interface ADDED - // - Check to see if this interface has a MIDI 1 pin. If not, bail. - // - Find the parent device (need to go an extra step up. Check KS code to see if similar now) - // - Check to see if we already have the parent device in the pending list. If not, create an entry in the pending list - // - Add the interface to the pending parent device - // - Reset the UPDATE_TIMEOUT timeout for this device (will need timeouts per-device, so maybe worker threads) - // - If no other changes for this device within that UPDATE_TIMEOUT, then - // - Build pin map - // - Build GTBs - // - Build name table - // - Create the device + return S_OK; +} +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::GetNextGroupIndex( + std::shared_ptr definition, + MidiFlow dataFlowFromUserPerspective, + uint8_t& groupIndex) +{ + // iterate through the pin definitions which match this data flow + // return last + 1 for the group index for this endpoint + // we could cache this value in the structure, but trying not to + // change the structs with this CFR bugfix - std::wstring transportCode(TRANSPORT_CODE); - std::wstring driverSuppliedName{}; + uint8_t nextGroupIndex{ 0 }; + for (auto const& pin : definition->MidiPins) + { + if (pin.DataFlowFromUserPerspective == dataFlowFromUserPerspective) + { + nextGroupIndex = max(pin.GroupIndex + 1, nextGroupIndex); + } + } + + groupIndex = nextGroupIndex; + + return S_OK; +} + + +_Use_decl_annotations_ +void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( + std::stop_token token) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + while (!token.stop_requested()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Waiting to be woken up.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // wait to be woken up + m_endpointCreationThreadWakeup.wait(); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: I'm awake now.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + if (!token.stop_requested()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Short nap before checking.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingUInt32(m_individualInterfaceEnumTimeoutMS, "nap period (ms)") + ); + + // we sleep for this timeout before we check to see if the thread is signaled. + // this gives time for an additional interface notification to cause the event + // to be reset + Sleep(m_individualInterfaceEnumTimeoutMS); + + // if we're still signaled, that means no other pnp notifications came in during the nap, or if they + // did, they completed within that nap period + if (m_endpointCreationThreadWakeup.is_signaled() && m_pendingEndpointDefinitions.size() > 0) + { + m_endpointCreationThreadWakeup.ResetEvent(); // it's a manual reset event + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Thread was signaled and pending definition count > 0. Proceed to processing queue.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + + // lock the definitions so we can process them + auto lock = m_pendingEndpointDefinitionsLock.lock(); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Processing pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + while (m_pendingEndpointDefinitions.size() > 0) + { + // effectively a queue, but because we have to iterate and search + // this in other functions, a vector is more appropriate + auto ep = m_pendingEndpointDefinitions[0]; + m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin()); + + // create the endpoint + LOG_IF_FAILED(CreateMidiUmpEndpoint(*ep)); + } + } + else + { + if (m_pendingEndpointDefinitions.size() == 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but no work to do. Pending count == 0.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but thread is no longer signaled", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + } + + + } + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Exit.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); +} + +_Use_decl_annotations_ +bool CMidi2KSAggregateMidiEndpointManager::KSAEndpointForDeviceExists( + _In_ std::wstring deviceInstanceId) +{ + for (auto const& entry : m_availableEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second.EndpointDeviceInstanceId) == + internal::NormalizeDeviceInstanceIdWStringCopy(deviceInstanceId.c_str())) + { + return true; + } + } + + return false; +} + +// TODO: If this is a new filter for an existing device, we need to update properties on that +// device, not recreate the endpoint + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( + DeviceWatcher /* watcher */, + DeviceInformation filterDevice +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") + ); + std::wstring transportCode(TRANSPORT_CODE); // Wrapper opens the handle internally. KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); RETURN_IF_FAILED(deviceHandleWrapper.Open()); - // Using lamba function to prevent handle from dissapearing when being used. - LOG_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - })); + std::shared_ptr endpointDefinition{ nullptr }; + + // ============================================================================================= + // Go through all the enumerated pins, looking for a MIDI 1.0 pin // enumerate all the pins for this filter ULONG cPins{ 0 }; - HRESULT hr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins)); - }); + })); - RETURN_IF_FAILED(hr); + std::wstring driverSuppliedName{}; + ULONG midiInputPinIndexForThisFilter{ 0 }; + ULONG midiOutputPinIndexForThisFilter{ 0 }; + + bool checkedForDriverSuppliedName{ false }; // process the pins for this filter. Not all will be MIDI pins for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++) { // bool isMidi1Pin{ false }; - // TODO + // TODO std::wstring customPortName{}; // Check the communication capabilities of the pin so we can fail fast @@ -869,104 +1160,79 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( RETURN_IF_NULL_ALLOC(handleDupe); // we try to open UMP only so we understand the device + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Checking for UMP pin. This will fallback error fail for non-UMP devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + KsHandleWrapper m_PinHandleWrapperUmp(filterDevice.Id().c_str(), pinIndex, MidiTransport_CyclicUMP, handleDupe.get()); if (SUCCEEDED(m_PinHandleWrapperUmp.Open())) { // this is a UMP pin. The KS transport will handle it, so we skip it here. // In the future, we may want to bail on the first UMP pin we find. + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found UMP/MIDI2 pin. Skipping for this transport.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + continue; } + // try to open as a MIDI 1 bytestream pin + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Checking for MIDI 1 pin. This will fallback error fail for non-MIDI devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + KsHandleWrapper m_PinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); if (SUCCEEDED(m_PinHandleWrapperMidi1.Open())) { - // this is a MIDI 1.0 byte format pin, so let's process it - KsAggregateEndpointMidiPinDefinition pinDefinition{ }; - - //pinDefinition.KSDriverSuppliedName = driverSuppliedName; - pinDefinition.PinNumber = pinIndex; - pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; - pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; - - // find the name of the pin (supplied by iJack, and if not available, generated by the stack) - std::wstring pinName{ }; + // don't wakeup right now. We've found a midi1 pin + m_endpointCreationThreadWakeup.ResetEvent(); - HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetPinName(h, pinIndex, pinName); - }); - if (SUCCEEDED(pinNameHr)) - { - pinDefinition.PinName = pinName; - } - // get the data flow so we know if this is a MIDI Input or a MIDI Output - KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; - HRESULT dataFlowHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetPinDataFlow(h, pinIndex, dataFlow); - }); + // check to see if this filter is for an endpoint device we've already created. + // if it is, we have to take a different approach to updating it. We don't want + // to just tear down and rebuild the current device, because that churns ports + // and can cause disconnections. - if (SUCCEEDED(dataFlowHr)) + auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); + if (KSAEndpointForDeviceExists(deviceInstanceId.c_str())) { - if (dataFlow == KSPIN_DATAFLOW_IN) - { - // MIDI Out (input to device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite - - pinDefinition.GroupIndex = static_cast(midiOutputGroupIndexForDevice); - - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); - - - - midiOutputPinIndexForThisFilter++; - midiOutputGroupIndexForDevice++; - } - else if (dataFlow == KSPIN_DATAFLOW_OUT) - { - // MIDI In (output from device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite - - pinDefinition.GroupIndex = static_cast(midiInputGroupIndexForDevice); - - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); - - midiInputPinIndexForThisFilter++; - midiInputGroupIndexForDevice++; - } - - // This is where we build the proposed names - // ================================================= - - std::wstring customName = L""; - - endpointDefinition.EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( - pinDefinition.GroupIndex, - pinDefinition.DataFlowFromUserPerspective, - customName, - driverSuppliedName, - pinDefinition.FilterName, - pinDefinition.PinName, - pinDefinition.PortIndexWithinThisFilterAndDirection + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"KSA endpoint for this filter already activated. TEMP skipping.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") ); - endpointDefinition.MidiPins.push_back(pinDefinition); + return S_OK; } - else - { - // this is a failure condition. Move on to next pin - RETURN_IF_FAILED(dataFlowHr); - } - } - } - @@ -977,112 +1243,166 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( - //auto additionalProperties = winrt::single_threaded_vector(); - auto properties = parentDevice.Properties(); - - KsAggregateEndpointDefinition endpointDefinition{ }; - - auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", parentDevice, L""); - RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); + if (endpointDefinition == nullptr) + { + // first MIDI 1 pin we're processing for this interface + RETURN_IF_FAILED(FindOrCreateMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); + } - auto systemDevicesParent = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Parent", parentDevice, L""); - - endpointDefinition.ParentDeviceName = parentDevice.Name(); - endpointDefinition.EndpointName = parentDevice.Name(); - endpointDefinition.ParentDeviceInstanceId = parentDevice.Id(); - - if (!systemDevicesParent.empty()) - { - LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(systemDevicesParent, endpointDefinition)); - } - // we set this if we find any compatible MIDI 1.0 byte format pins - bool isCompatibleMidi1Device{ false }; + if (driverSuppliedName.empty() && !checkedForDriverSuppliedName) + { + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail + deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); - // enumerate all KS_CATEGORY_AUDIO filters for this parent media device - winrt::hstring filterDeviceSelector( - L"System.Devices.InterfaceClassGuid:=\"" + winrt::hstring(KS_CATEGORY_AUDIO_GUID) + "\""\ - L" AND System.Devices.InterfaceEnabled:= System.StructuredQueryType.Boolean#True"\ - L" AND System.Devices.DeviceInstanceId:= \"" + deviceInstanceId + L"\""); + checkedForDriverSuppliedName = true; + if (driverSuppliedName.empty()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No driver-supplied name", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") + ); + } + } - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enumerating Filters", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(parentDevice.Id().c_str(), "parent device id") - ); + // this is a MIDI 1.0 byte format pin, so let's process it + KsAggregateEndpointMidiPinDefinition pinDefinition{ }; - auto filterDevices = DeviceInformation::FindAllAsync(filterDeviceSelector).get(); + pinDefinition.PinNumber = pinIndex; + pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; + pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; - //ULONG midiInputGroupIndexForDevice{ 0 }; - //ULONG midiOutputGroupIndexForDevice{ 0 }; + // find the name of the pin (supplied by iJack, and if not available, generated by the stack) + std::wstring pinName{ }; + HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinName(h, pinIndex, pinName); + }); - //ULONG midiInputPinIndexForThisFilter{ 0 }; - //ULONG midiOutputPinIndexForThisFilter{ 0 }; + if (SUCCEEDED(pinNameHr)) + { + pinDefinition.PinName = pinName; + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Pin has name", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(pinDefinition.PinName.c_str(), "pin name") + ); + } + // get the data flow so we know if this is a MIDI Input (Source) or a MIDI Output (Destination) + KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinDataFlow(h, pinIndex, dataFlow); + })); - // now create the device + if (dataFlow == KSPIN_DATAFLOW_IN) + { + // MIDI Out (input to device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); - if (isCompatibleMidi1Device && endpointDefinition.MidiPins.size() > 0) - { - // only some vendor drivers provide an actual manufacturer - // and all the in-box drivers just provide the Generic USB Audio string - // TODO: Is "Generic USB Audio" a string that is localized? If so, this - // code will not have the intended effect outside of en-US - auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); - if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") + midiOutputPinIndexForThisFilter++; + } + else if (dataFlow == KSPIN_DATAFLOW_OUT) { - endpointDefinition.ManufacturerName = manufacturer; + // MIDI In (output from device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); + + midiInputPinIndexForThisFilter++; } - // default hash is the device id. We don't have an iSerial here. - std::hash hasher; - std::wstring hash; - hash = std::to_wstring(hasher(endpointDefinition.ParentDeviceInstanceId)); + // not being able to get the group index is fatal + RETURN_IF_FAILED(GetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); - endpointDefinition.EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; + // This is where we build the proposed names + // ================================================= - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Creating aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + std::wstring customName = L""; // TODO + + endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition.GroupIndex, + pinDefinition.DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition.FilterName, + pinDefinition.PinName, + pinDefinition.PortIndexWithinThisFilterAndDirection ); - // We've enumerated all the pins on the device. Now create the aggregated UMP endpoint - RETURN_IF_FAILED(CreateMidiUmpEndpoint(endpointDefinition)); - } - else - { + endpointDefinition->MidiPins.push_back(pinDefinition); + TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, + MIDI_TRACE_EVENT_VERBOSE, TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No compatible MIDI pins. This is normal in most cases.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + TraceLoggingWideString(L"MIDI 1.0 pin added", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") ); } - + } - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Filter Enumeration Complete", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(parentDevice.Id().c_str(), "parent device id") - ); + if (endpointDefinition == nullptr || endpointDefinition->MidiPins.size() == 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No MIDI 1.0 pins found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + } + else if (endpointDefinition->MidiPins.size() > 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Filter pin Enumeration Complete", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingUInt32(static_cast(endpointDefinition->MidiPins.size()), "total MIDI 1.0 pin count") + ); + + m_endpointCreationThreadWakeup.SetEvent(); + } return S_OK; } @@ -1629,6 +1949,12 @@ CMidi2KSAggregateMidiEndpointManager::Shutdown() TraceLoggingPointer(this, "this") ); + if (NewMidiFeatureUpdateKsa2603Enabled()) + { + m_endpointCreationThread.request_stop(); + m_endpointCreationThreadWakeup.SetEvent(); + } + m_DeviceAdded.revoke(); m_DeviceRemoved.revoke(); m_DeviceUpdated.revoke(); diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index fbc639fe1..25fbfaee2 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -93,6 +93,25 @@ class CMidi2KSAggregateMidiEndpointManager : wil::critical_section m_availableEndpointDefinitionsLock; std::map m_availableEndpointDefinitions; + wil::critical_section m_pendingEndpointDefinitionsLock; + std::vector> m_pendingEndpointDefinitions; + HRESULT FindOrCreateMasterEndpointDefinitionForFilterDevice( + _In_ DeviceInformation, + _In_ std::shared_ptr&); + + bool KSAEndpointForDeviceExists( + _In_ std::wstring deviceInstanceId); + + HRESULT GetNextGroupIndex( + _In_ std::shared_ptr definition, + _In_ MidiFlow dataFlowFromUserPerspective, + _In_ uint8_t& groupIndex); + + wil::unique_event_nothrow m_endpointCreationThreadWakeup; + std::jthread m_endpointCreationThread; + void EndpointCreationThreadWorker(_In_ std::stop_token token); + + DeviceWatcher m_watcher{0}; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Removed_revoker m_DeviceRemoved; diff --git a/src/api/Transport/KSAggregateTransport/pch.h b/src/api/Transport/KSAggregateTransport/pch.h index b0474480d..f34d0fb1e 100644 --- a/src/api/Transport/KSAggregateTransport/pch.h +++ b/src/api/Transport/KSAggregateTransport/pch.h @@ -36,6 +36,8 @@ #include #include +#include +#include #define _ATL_APARTMENT_THREADED From d50b381c5726291c9f0017a8b6f7a6d576383650 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sun, 8 Feb 2026 01:49:57 -0500 Subject: [PATCH 04/18] Initial enumeration for KSA back up and running with new structure --- .nuke/build.schema.json | 16 +--- build/staging/version/BundleInfo.wxi | 6 +- .../version/WindowsMidiServicesVersion.cs | 16 ++-- .../version/WindowsMidiServicesVersion.h | 14 ++-- .../Midi2.KSAggregateMidiBidi.cpp | 1 + .../Midi2.KSAggregateMidiEndpointManager.cpp | 78 +++++++++++++------ .../Midi2.KSAggregateMidiInProxy.cpp | 1 + 7 files changed, 80 insertions(+), 52 deletions(-) diff --git a/.nuke/build.schema.json b/.nuke/build.schema.json index 55ab3be36..d08d610e7 100644 --- a/.nuke/build.schema.json +++ b/.nuke/build.schema.json @@ -25,21 +25,13 @@ "type": "string", "enum": [ "BuildAndPublishAll", - "T_BuildAndPackageElectronProjection", - "T_BuildAndPackAllAppSDKs", - "T_BuildAppSdkRuntimeAndToolsInstaller", - "T_BuildAppSDKToolsAndTests", - "T_BuildConsoleApp", - "T_BuildCppSamples", - "T_BuildCSharpSamples", - "T_BuildPowerShellProjection", - "T_BuildSettingsApp", - "T_BuildUserToolsSharedComponents", - "T_CopySharedDesignAssets", + "T_BuildServiceAndPlugins", + "T_BuildServiceAndPluginsInstaller", "T_CreateVersionIncludes", "T_Prerequisites", "T_ZipPowershellDevUtilities", - "T_ZipSamples" + "T_ZipServicePdbs", + "T_ZipWdmaud2" ] }, "Verbosity": { diff --git a/build/staging/version/BundleInfo.wxi b/build/staging/version/BundleInfo.wxi index 73b94e473..c352c6d40 100644 --- a/build/staging/version/BundleInfo.wxi +++ b/build/staging/version/BundleInfo.wxi @@ -1,5 +1,5 @@ - - - + + + diff --git a/build/staging/version/WindowsMidiServicesVersion.cs b/build/staging/version/WindowsMidiServicesVersion.cs index 7522e02a8..c211cdaf8 100644 --- a/build/staging/version/WindowsMidiServicesVersion.cs +++ b/build/staging/version/WindowsMidiServicesVersion.cs @@ -9,15 +9,15 @@ public static class MidiNuGetBuildInformation { public const bool IsPreview = true; public const string Source = "GitHub Preview"; - public const string BuildDate = "2026-01-18"; - public const string Name = "SDK Release Candidate 1"; - public const string BuildFullVersion = "1.0.14-rc.1.213"; + public const string BuildDate = "2026-02-07"; + public const string Name = "Service Preview 14"; + public const string BuildFullVersion = "1.0.15-preview.14.74"; public const ushort VersionMajor = 1; public const ushort VersionMinor = 0; - public const ushort VersionPatch = 14; - public const ushort VersionBuildNumber = 213; - public const string Preview = "rc.1.213"; - public const string AssemblyFullVersion = "1.0.14.213"; - public const string FileFullVersion = "1.0.14.213"; + public const ushort VersionPatch = 15; + public const ushort VersionBuildNumber = 74; + public const string Preview = "preview.14.74"; + public const string AssemblyFullVersion = "1.0.15.74"; + public const string FileFullVersion = "1.0.15.74"; } } diff --git a/build/staging/version/WindowsMidiServicesVersion.h b/build/staging/version/WindowsMidiServicesVersion.h index 498434dd4..fb9913615 100644 --- a/build/staging/version/WindowsMidiServicesVersion.h +++ b/build/staging/version/WindowsMidiServicesVersion.h @@ -7,14 +7,14 @@ #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_IS_PREVIEW true #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_SOURCE L"GitHub Preview" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-01-18" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_NAME L"SDK Release Candidate 1" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.14-rc.1.213" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-02-07" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_NAME L"Service Preview 14" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.15-preview.14.74" #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MAJOR 1 #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MINOR 0 -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_PATCH 14 -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 213 -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"rc.1.213" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.14.213" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_PATCH 15 +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 74 +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"preview.14.74" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.15.74" #endif diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp index 3e5526605..64692c741 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiBidi.cpp @@ -35,6 +35,7 @@ CMidi2KSAggregateMidiBidi::Initialize( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), + TraceLoggingPointer(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), TraceLoggingWideString(device, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) ); diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 43a260813..96db61352 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -69,6 +69,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( L"System.Devices.InterfaceEnabled: = System.StructuredQueryType.Boolean#True"); auto additionalProps = winrt::single_threaded_vector(); + additionalProps.Append(L"System.Devices.Parent"); m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector); @@ -125,6 +126,22 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( // Wait for everything to be created so that they're available immediately after service start. m_EnumerationCompleted.wait(INITIAL_ENUMERATION_TIMEOUT_MS); + if (NewMidiFeatureUpdateKsa2603Enabled()) + { + if (m_pendingEndpointDefinitions.size() > 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enumeration completed with endpoint definitions left pending.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingUInt32(static_cast(m_pendingEndpointDefinitions.size()), "count pending definitions") + ); + } + } + return S_OK; } @@ -830,11 +847,11 @@ CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFil auto lock = m_pendingEndpointDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing - auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(deviceInstanceId.c_str()); + auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); auto it = std::find_if( m_pendingEndpointDefinitions.begin(), m_pendingEndpointDefinitions.end(), - [&parentInstanceIdToFind](const std::shared_ptr def){return def->ParentDeviceInstanceId == parentInstanceIdToFind; }); + [&parentInstanceIdToFind](const std::shared_ptr def){return internal::NormalizeDeviceInstanceIdWStringCopy(def->ParentDeviceInstanceId) == parentInstanceIdToFind; }); if (it != m_pendingEndpointDefinitions.end()) { @@ -865,7 +882,6 @@ CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFil LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newEndpointDefinition->ParentDeviceInstanceId.c_str(), *newEndpointDefinition)); - // TEMP. Remove before publishing TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_VERBOSE, @@ -873,9 +889,7 @@ CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFil TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Creating new aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(newEndpointDefinition->ParentDeviceInstanceId.c_str(), "parent"), - TraceLoggingUInt16(newEndpointDefinition->VID, "VID"), - TraceLoggingUInt16(newEndpointDefinition->PID, "PID") + TraceLoggingWideString(newEndpointDefinition->ParentDeviceInstanceId.c_str(), "parent") ); // only some vendor drivers provide an actual manufacturer @@ -1012,7 +1026,7 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, + MIDI_TRACE_EVENT_VERBOSE, TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), @@ -1029,6 +1043,15 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( // create the endpoint LOG_IF_FAILED(CreateMidiUmpEndpoint(*ep)); } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); } else { @@ -1055,8 +1078,6 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( ); } } - - } } @@ -1066,7 +1087,7 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Exit.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + TraceLoggingWideString(L"Exit", MIDI_TRACE_EVENT_MESSAGE_FIELD) ); @@ -1074,12 +1095,12 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( _Use_decl_annotations_ bool CMidi2KSAggregateMidiEndpointManager::KSAEndpointForDeviceExists( - _In_ std::wstring deviceInstanceId) + _In_ std::wstring parentDeviceInstanceId) { for (auto const& entry : m_availableEndpointDefinitions) { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second.EndpointDeviceInstanceId) == - internal::NormalizeDeviceInstanceIdWStringCopy(deviceInstanceId.c_str())) + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second.ParentDeviceInstanceId) == + internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) { return true; } @@ -1212,14 +1233,17 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( + // TODO ============================================================== + // check to see if this filter is for an endpoint device we've already created. // if it is, we have to take a different approach to updating it. We don't want // to just tear down and rebuild the current device, because that churns ports // and can cause disconnections. - auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); - RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); - if (KSAEndpointForDeviceExists(deviceInstanceId.c_str())) + //auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); + if (KSAEndpointForDeviceExists(parentInstanceId.c_str())) { TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1228,17 +1252,14 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"KSA endpoint for this filter already activated. TEMP skipping.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") ); return S_OK; } - - - - - + // END TODO ============================================================== @@ -1346,6 +1367,19 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( // not being able to get the group index is fatal RETURN_IF_FAILED(GetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), + TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == + MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") + ); + // This is where we build the proposed names // ================================================= diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp index a903ec2cb..bfec5a708 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp @@ -30,6 +30,7 @@ CMidi2KSAggregateMidiInProxy::Initialize( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), TraceLoggingWideString(endpointDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD), TraceLoggingUInt32(pinId, "Pin id"), TraceLoggingUInt8(groupIndex, "Group index") From 0713db46ccf47e59654e309b2232c628be88786e Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sun, 8 Feb 2026 13:13:26 -0500 Subject: [PATCH 05/18] Fix incorrect debug mode check in KSA Midi In Proxy --- .../KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp index bfec5a708..4862ac3d3 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiInProxy.cpp @@ -118,7 +118,7 @@ CMidi2KSAggregateMidiInProxy::Callback( RETURN_HR_IF_NULL(E_POINTER, m_callback); RETURN_HR_IF_NULL(E_POINTER, m_bs2UmpTransform); -#ifndef _DEBUG +#ifdef _DEBUG TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_VERBOSE, From dab1647a34949ce2a6a413cf78281fef2cf8f8a3 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sun, 8 Feb 2026 15:05:35 -0500 Subject: [PATCH 06/18] Almost there If you start up loopMIDI when the service is already running, all your existing ports will appear. Still working on support for ports added or removed after loopMIDI has already started. --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 344 +++++++++--------- .../Midi2.KSAggregateMidiEndpointManager.h | 23 +- 2 files changed, 192 insertions(+), 175 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 96db61352..c39fe2a98 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -927,30 +927,26 @@ CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFil _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::GetNextGroupIndex( +CMidi2KSAggregateMidiEndpointManager::IncrementAndGetNextGroupIndex( std::shared_ptr definition, MidiFlow dataFlowFromUserPerspective, uint8_t& groupIndex) { - // iterate through the pin definitions which match this data flow - // return last + 1 for the group index for this endpoint - // we could cache this value in the structure, but trying not to - // change the structs with this CFR bugfix - - uint8_t nextGroupIndex{ 0 }; - for (auto const& pin : definition->MidiPins) + if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn) { - if (pin.DataFlowFromUserPerspective == dataFlowFromUserPerspective) - { - nextGroupIndex = max(pin.GroupIndex + 1, nextGroupIndex); - } + definition->CurrentHighestMidiSourceGroupIndex++; + groupIndex = definition->CurrentHighestMidiSourceGroupIndex; + } + else + { + definition->CurrentHighestMidiDestinationGroupIndex++; + groupIndex = definition->CurrentHighestMidiDestinationGroupIndex; } - - groupIndex = nextGroupIndex; return S_OK; } +#define MAX_THREAD_WORKER_WAIT_TIME_MS 20000 _Use_decl_annotations_ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( @@ -977,7 +973,7 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( ); // wait to be woken up - m_endpointCreationThreadWakeup.wait(); + m_endpointCreationThreadWakeup.wait(MAX_THREAD_WORKER_WAIT_TIME_MS); TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1053,6 +1049,8 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) ); } + +#ifdef _DEBUG else { if (m_pendingEndpointDefinitions.size() == 0) @@ -1078,6 +1076,7 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( ); } } +#endif } } @@ -1109,34 +1108,18 @@ bool CMidi2KSAggregateMidiEndpointManager::KSAEndpointForDeviceExists( return false; } -// TODO: If this is a new filter for an existing device, we need to update properties on that -// device, not recreate the endpoint _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( - DeviceWatcher /* watcher */, - DeviceInformation filterDevice +CMidi2KSAggregateMidiEndpointManager::GetMidi1FilterPins( + DeviceInformation filterDevice, + std::vector& pinListToAddTo ) { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") - ); - - std::wstring transportCode(TRANSPORT_CODE); - // Wrapper opens the handle internally. KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); RETURN_IF_FAILED(deviceHandleWrapper.Open()); - std::shared_ptr endpointDefinition{ nullptr }; - // ============================================================================================= // Go through all the enumerated pins, looking for a MIDI 1.0 pin @@ -1147,20 +1130,12 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins)); })); - std::wstring driverSuppliedName{}; ULONG midiInputPinIndexForThisFilter{ 0 }; ULONG midiOutputPinIndexForThisFilter{ 0 }; - bool checkedForDriverSuppliedName{ false }; - // process the pins for this filter. Not all will be MIDI pins for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++) { - // bool isMidi1Pin{ false }; - - // TODO - std::wstring customPortName{}; - // Check the communication capabilities of the pin so we can fail fast KSPIN_COMMUNICATION communication = (KSPIN_COMMUNICATION)0; @@ -1222,94 +1197,9 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") ); - KsHandleWrapper m_PinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); - if (SUCCEEDED(m_PinHandleWrapperMidi1.Open())) + KsHandleWrapper pinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); + if (SUCCEEDED(pinHandleWrapperMidi1.Open())) { - // don't wakeup right now. We've found a midi1 pin - m_endpointCreationThreadWakeup.ResetEvent(); - - - - - - - // TODO ============================================================== - - // check to see if this filter is for an endpoint device we've already created. - // if it is, we have to take a different approach to updating it. We don't want - // to just tear down and rebuild the current device, because that churns ports - // and can cause disconnections. - - //auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); - auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); - RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); - if (KSAEndpointForDeviceExists(parentInstanceId.c_str())) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"KSA endpoint for this filter already activated. TEMP skipping.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") - ); - - return S_OK; - } - - // END TODO ============================================================== - - - - - - if (endpointDefinition == nullptr) - { - // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreateMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); - RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); - } - - - if (driverSuppliedName.empty() && !checkedForDriverSuppliedName) - { - // Using lamba function to prevent handle from dissapearing when being used. - // we don't log the HRESULT because this is not critical and will often fail - deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - }); - - checkedForDriverSuppliedName = true; - - if (driverSuppliedName.empty()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No driver-supplied name", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") - ); - } - } - // this is a MIDI 1.0 byte format pin, so let's process it KsAggregateEndpointMidiPinDefinition pinDefinition{ }; @@ -1364,38 +1254,7 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( midiInputPinIndexForThisFilter++; } - // not being able to get the group index is fatal - RETURN_IF_FAILED(GetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), - TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == - MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") - ); - - // This is where we build the proposed names - // ================================================= - - std::wstring customName = L""; // TODO - - endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( - pinDefinition.GroupIndex, - pinDefinition.DataFlowFromUserPerspective, - customName, - driverSuppliedName, - pinDefinition.FilterName, - pinDefinition.PinName, - pinDefinition.PortIndexWithinThisFilterAndDirection - ); - - endpointDefinition->MidiPins.push_back(pinDefinition); + pinListToAddTo.push_back(pinDefinition); TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1410,7 +1269,45 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( } - if (endpointDefinition == nullptr || endpointDefinition->MidiPins.size() == 0) + return S_OK; +} + + + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( + DeviceWatcher /* watcher */, + DeviceInformation filterDevice +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") + ); + + std::wstring transportCode(TRANSPORT_CODE); + + // Wrapper opens the handle internally. + KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); + RETURN_IF_FAILED(deviceHandleWrapper.Open()); + + std::shared_ptr endpointDefinition{ nullptr }; + + // get all the MIDI 1 pins. These are only partially processed because some things + // like group index and naming require the full context of all filters/pins on the + // parent device. + std::vector pinList{ }; + RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList)); + + if (pinList.size() == 0) { TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1418,25 +1315,136 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No MIDI 1.0 pins found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(L"No MIDI 1.0 pins for this filter device", MIDI_TRACE_EVENT_MESSAGE_FIELD), TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") ); + + return S_OK; } - else if (endpointDefinition->MidiPins.size() > 0) + + + // check to see if we already have an activated endpoint for this filter + auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); + if (KSAEndpointForDeviceExists(parentInstanceId.c_str())) { TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, + MIDI_TRACE_EVENT_VERBOSE, TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Filter pin Enumeration Complete", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(L"KSA endpoint for this filter already activated. TEMP skipping.", MIDI_TRACE_EVENT_MESSAGE_FIELD), TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingUInt32(static_cast(endpointDefinition->MidiPins.size()), "total MIDI 1.0 pin count") + TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") ); - m_endpointCreationThreadWakeup.SetEvent(); + // TODO: We need to add this interface to the existing device + return S_OK; + // END TODO ============================================================== } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Endpoint for this device does not already exist. Proceed to creating new one.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") + ); + } + + + // if the endpointDefinition is null, that means we haven't found an existing + // activated endpoint definition we need to use, and so we proceed to check + // for an existing pending endpoint definition. If found, it's used. If not + // found, the function will create a new one for us to use, with all the + // endpoint-specific details (excluding pins) populated. + if (endpointDefinition == nullptr) + { + // first MIDI 1 pin we're processing for this interface + RETURN_IF_FAILED(FindOrCreateMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); + + // add our new pins into the existing endpoint definition + endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + pinList.clear(); // just make sure we don't use this one, accidentally + } + + // Get the driver-supplied name. This is needed for WinMM backwards compatibility + std::wstring driverSuppliedName{}; + + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail, + // adding unnecessary noise to the error logs + deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); + +#ifdef _DEBUG + if (!driverSuppliedName.empty()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") + ); + } +#endif + + // At this point, we need to have *all* the pins for the endpoint, not just this filter + for (auto& pinDefinition : endpointDefinition->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != + internal::NormalizeDeviceInstanceIdWStringCopy(filterDevice.Id().c_str())) + { + // only process the pins for this filter interface. We don't want to + // change anything that has already been built. But we do need the + // context of all pins when getting the group index. + continue; + } + + // Figure out the group index for the pin. This needs the context of the + // entire device. Failure to get the group index is fatal + RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), + TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == + MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") + ); + + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update + + // Build the name table entry for this individual pin + endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition.GroupIndex, + pinDefinition.DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition.FilterName, + pinDefinition.PinName, + pinDefinition.PortIndexWithinThisFilterAndDirection + ); + } + + // we have an endpoint definition + m_endpointCreationThreadWakeup.SetEvent(); return S_OK; } diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index 25fbfaee2..844acfa0a 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -56,6 +56,10 @@ struct KsAggregateEndpointDefinition std::vector MidiPins{ }; WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{}; + + // new for 2603 fix, but does not impact the existing code + int8_t CurrentHighestMidiSourceGroupIndex{ -1 }; + int8_t CurrentHighestMidiDestinationGroupIndex{ -1 }; }; @@ -95,17 +99,22 @@ class CMidi2KSAggregateMidiEndpointManager : wil::critical_section m_pendingEndpointDefinitionsLock; std::vector> m_pendingEndpointDefinitions; + HRESULT FindOrCreateMasterEndpointDefinitionForFilterDevice( - _In_ DeviceInformation, - _In_ std::shared_ptr&); + _In_ DeviceInformation, + _In_ std::shared_ptr&); + HRESULT GetMidi1FilterPins( + _In_ DeviceInformation, + _In_ std::vector&); + bool KSAEndpointForDeviceExists( - _In_ std::wstring deviceInstanceId); + _In_ std::wstring deviceInstanceId); - HRESULT GetNextGroupIndex( - _In_ std::shared_ptr definition, - _In_ MidiFlow dataFlowFromUserPerspective, - _In_ uint8_t& groupIndex); + HRESULT IncrementAndGetNextGroupIndex( + _In_ std::shared_ptr definition, + _In_ MidiFlow dataFlowFromUserPerspective, + _In_ uint8_t& groupIndex); wil::unique_event_nothrow m_endpointCreationThreadWakeup; std::jthread m_endpointCreationThread; From 398a1b64ac768ac7c12ef6a6e74aff7b21e60bb8 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sun, 8 Feb 2026 18:54:02 -0500 Subject: [PATCH 07/18] Now supports addition of loopMIDI ports post-enumeration --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 802 +++++++++++++++++- .../Midi2.KSAggregateMidiEndpointManager.h | 30 +- 2 files changed, 798 insertions(+), 34 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index c39fe2a98..9f70d9219 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -595,6 +595,638 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint( } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::UpdateNameTableWithCustomProperties( + std::shared_ptr masterEndpointDefinition, + std::shared_ptr customProperties) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + + for (auto const& pinEntry : masterEndpointDefinition->MidiPins) + { + if (customProperties != nullptr && + (customProperties->Midi1Destinations.size() > 0 || customProperties->Midi1Sources.size() > 0)) + { + if (pinEntry.PinDataFlow == MidiFlow::MidiFlowIn) + { + // message destination (output port), pin flow is In + if (auto customConfiguredName = customProperties->Midi1Destinations.find(pinEntry.GroupIndex); + customConfiguredName != customProperties->Midi1Destinations.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom name for a Midi 1 destination.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), + TraceLoggingUInt8(pinEntry.GroupIndex, "group index") + ); + + masterEndpointDefinition->EndpointNameTable.UpdateDestinationEntryCustomName(pinEntry.GroupIndex, customConfiguredName->second.Name); + } + } + else if (pinEntry.PinDataFlow == MidiFlow::MidiFlowOut) + { + // message source (input port), pin flow is Out + if (auto customConfiguredName = customProperties->Midi1Sources.find(pinEntry.GroupIndex); + customConfiguredName != customProperties->Midi1Sources.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom name for a Midi 1 source.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), + TraceLoggingUInt8(pinEntry.GroupIndex, "group index") + ); + + masterEndpointDefinition->EndpointNameTable.UpdateSourceEntryCustomName(pinEntry.GroupIndex, customConfiguredName->second.Name); + } + } + } + + } + + + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::BuildPinsAndGroupTerminalBlocksPropertyData( + std::shared_ptr masterEndpointDefinition, + std::vector& pinMapPropertyData, + std::vector& groupTerminalBlocks) +{ + uint8_t currentBlockNumber{ 0 }; + std::vector pinMapEntries{ }; + + for (auto const& pin : masterEndpointDefinition->MidiPins) + { + RETURN_HR_IF(E_INVALIDARG, pin.FilterDeviceId.empty()); + + internal::GroupTerminalBlockInternal gtb; + + gtb.Number = ++currentBlockNumber; + gtb.GroupCount = 1; // always a single group for aggregate MIDI 1.0 devices + + PinMapEntryStagingEntry pinMapEntry{ }; + + pinMapEntry.PinId = pin.PinNumber; + pinMapEntry.FilterId = pin.FilterDeviceId; + pinMapEntry.PinDataFlow = pin.PinDataFlow; + + //MidiFlow flowFromUserPerspective; + + pinMapEntry.GroupIndex = pin.GroupIndex; + gtb.FirstGroupIndex = pin.GroupIndex; + + if (pin.PinDataFlow == MidiFlow::MidiFlowIn) // pin in, so user out : A MIDI Destination + { + gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_INPUT; // from the pin/gtb's perspective + + auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetDestinationEntry(gtb.FirstGroupIndex); + if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) + { + gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); + } + } + else if (pin.PinDataFlow == MidiFlow::MidiFlowOut) // pin out, so user in : A MIDI Source + { + gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_OUTPUT; // from the pin/gtb's perspective + auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetSourceEntry(gtb.FirstGroupIndex); + if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) + { + gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); + } + } + else + { + RETURN_IF_FAILED(E_INVALIDARG); + } + + // name fallback + if (gtb.Name.empty()) + { + gtb.Name = masterEndpointDefinition->EndpointName; + + if (gtb.FirstGroupIndex > 0) + { + gtb.Name += L" " + std::wstring{ gtb.FirstGroupIndex }; + } + } + + // default values as defined in the MIDI 2.0 USB spec + gtb.Protocol = 0x01; // midi 1.0 + gtb.MaxInputBandwidth = 0x0001; // 31.25 kbps + gtb.MaxOutputBandwidth = 0x0001; // 31.25 kbps + + groupTerminalBlocks.push_back(gtb); + pinMapEntries.push_back(pinMapEntry); + } + + // Write Pin Map Property + // ===================================================== + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Building pin map property", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // build the pin map property value + KSAGGMIDI_PIN_MAP_PROPERTY_VALUE pinMap{ }; + + size_t totalStringSizesIncludingNulls{ 0 }; + for (auto const& entry : pinMapEntries) + { + totalStringSizesIncludingNulls += ((entry.FilterId.length() + 1) * sizeof(wchar_t)); + } + + size_t totalMemoryBytes{ + SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER + + SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING * pinMapEntries.size() + + totalStringSizesIncludingNulls }; + + pinMapPropertyData.resize(totalMemoryBytes); + auto currentPos = pinMapPropertyData.data(); + + // header + auto pinMapHeader = (PKSAGGMIDI_PIN_MAP_PROPERTY_VALUE)currentPos; + pinMapHeader->TotalByteCount = (UINT32)totalMemoryBytes; + currentPos += SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER; + + for (auto const& entry : pinMapEntries) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Processing Pin Map entry", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingUInt32(entry.PinId, "Pin Id"), + TraceLoggingWideString(entry.FilterId.c_str(), "Filter Id") + ); + + PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY propEntry = (PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY)currentPos; + + propEntry->ByteCount = (UINT)(SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING + ((entry.FilterId.length() + 1) * sizeof(wchar_t))); + propEntry->GroupIndex = entry.GroupIndex; + propEntry->PinDataFlow = entry.PinDataFlow; + propEntry->PinId = entry.PinId; + + if (!entry.FilterId.empty()) + { + wcscpy_s((wchar_t*)propEntry->FilterId, entry.FilterId.length() + 1, entry.FilterId.c_str()); + } + + currentPos += propEntry->ByteCount; + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"All pin map entries copied to property memory", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpointV2( + std::shared_ptr masterEndpointDefinition +) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + DEVPROP_BOOLEAN devPropTrue = DEVPROP_TRUE; + + // we require at least one valid pin + RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + + std::vector interfaceDevProperties; + + MIDIENDPOINTCOMMONPROPERTIES commonProperties{}; + commonProperties.TransportId = TRANSPORT_LAYER_GUID; + commonProperties.EndpointDeviceType = MidiEndpointDeviceType_Normal; + commonProperties.FriendlyName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.TransportCode = TRANSPORT_CODE; + commonProperties.EndpointName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.EndpointDescription = nullptr; + commonProperties.CustomEndpointName = nullptr; + commonProperties.CustomEndpointDescription = nullptr; + commonProperties.UniqueIdentifier = masterEndpointDefinition->SerialNumber.empty() ? nullptr : masterEndpointDefinition->SerialNumber.c_str(); + commonProperties.ManufacturerName = masterEndpointDefinition->ManufacturerName.empty() ? nullptr : masterEndpointDefinition->ManufacturerName.c_str(); + commonProperties.SupportedDataFormats = MidiDataFormats::MidiDataFormats_UMP; + commonProperties.NativeDataFormat = MidiDataFormats_ByteStream; + + UINT32 capabilities{ 0 }; + capabilities |= MidiEndpointCapabilities_SupportsMultiClient; + capabilities |= MidiEndpointCapabilities_GenerateIncomingTimestamps; + capabilities |= MidiEndpointCapabilities_SupportsMidi1Protocol; + commonProperties.Capabilities = (MidiEndpointCapabilities)capabilities; + + + std::vector pinMapPropertyData; + std::vector groupTerminalBlocks{ }; + std::vector nameTablePropertyData; + + RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( + masterEndpointDefinition, + pinMapPropertyData, + groupTerminalBlocks)); + + + interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); + + + // Write Group Terminal Block Property + // ===================================================== + + std::vector groupTerminalBlockPropertyData{}; + + if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockPropertyData)) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(groupTerminalBlockPropertyData.size()), + (PVOID)groupTerminalBlockPropertyData.data() }); + } + else + { + // write empty data + } + + + // Fold in custom properties, including MIDI 1 port names and naming approach + // =============================================================================== + + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.UsbVendorId = masterEndpointDefinition->VID; + matchCriteria.UsbProductId = masterEndpointDefinition->PID; + matchCriteria.UsbSerialNumber = masterEndpointDefinition->SerialNumber; + matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + + auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); + + // rebuild the name table, using the custom properties if present + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + + std::wstring customName{ }; + std::wstring customDescription{ }; + if (customProperties != nullptr) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), + TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") + ); + + if (!customProperties->Name.empty()) + { + customName = customProperties->Name; + commonProperties.CustomEndpointName = customName.c_str(); + } + + if (!customProperties->Description.empty()) + { + customDescription = customProperties->Description; + commonProperties.CustomEndpointDescription = customDescription.c_str(); + } + + // this includes image, the Midi 1 naming approach, etc. + customProperties->WriteNonCommonProperties(interfaceDevProperties); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + ); + } + + // Write Name table property, folding in the custom names we discovered earlier + // =============================================================================================== + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + + // Write USB VID/PID Data + // ===================================================== + + if (masterEndpointDefinition->VID > 0) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&masterEndpointDefinition->VID }); + } + else + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_EMPTY, 0, nullptr }); + } + + if (masterEndpointDefinition->PID > 0) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&masterEndpointDefinition->PID }); + } + else + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_EMPTY, 0, nullptr }); + } + + + // Despite being a MIDI 1 device, we present as a UMP endpoint, so we need to set + // this property so the service can create the MIDI 1 ports without waiting for + // function blocks/discovery to complete or timeout (which it never will) + interfaceDevProperties.push_back({ { PKEY_MIDI_EndpointDiscoveryProcessComplete, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BOOLEAN, (ULONG)sizeof(devPropTrue), (PVOID)&devPropTrue }); + + SW_DEVICE_CREATE_INFO createInfo{ }; + + createInfo.cbSize = sizeof(createInfo); + createInfo.pszInstanceId = masterEndpointDefinition->EndpointDeviceInstanceId.c_str(); + createInfo.CapabilityFlags = SWDeviceCapabilitiesNone; + createInfo.pszDeviceDescription = masterEndpointDefinition->EndpointName.c_str(); + + // Call the device manager and finish the creation + + HRESULT swdCreationResult; + wil::unique_cotaskmem_string newDeviceInterfaceId; + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Activating endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // set to true if we only want to create UMP endpoints + bool umpOnly = false; + + LOG_IF_FAILED( + swdCreationResult = m_midiDeviceManager->ActivateEndpoint( + masterEndpointDefinition->ParentDeviceInstanceId.c_str(), + umpOnly, + MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint + &commonProperties, + (ULONG)interfaceDevProperties.size(), + (ULONG)0, + interfaceDevProperties.data(), + nullptr, + &createInfo, + &newDeviceInterfaceId) + ); + + if (SUCCEEDED(swdCreationResult)) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingWideString(newDeviceInterfaceId.get(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) + ); + + // return new device interface id + masterEndpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); + + auto lock = m_availableEndpointDefinitionsLock.lock(); + + // Add to internal endpoint manager + m_availableEndpointDefinitionsV2.insert_or_assign( + internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->ParentDeviceInstanceId), + masterEndpointDefinition); + + return swdCreationResult; + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD) + ); + + return swdCreationResult; + } +} + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::UpdateExistingMidiUmpEndpointWithFilterChanges( + std::shared_ptr masterEndpointDefinition +) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // we require at least one valid pin + RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + + std::vector interfaceDevProperties{ }; + + std::vector pinMapPropertyData; + std::vector groupTerminalBlocks{ }; + std::vector nameTablePropertyData; + + // update the pin map to have all the existing pins + // plus the new pins. Update Group Terminal Blocks at the same time. + RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( + masterEndpointDefinition, + pinMapPropertyData, + groupTerminalBlocks)); + + + // Write Pin Map Property + // ===================================================== + interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); + + + // Write Group Terminal Block Property + // ===================================================== + + std::vector groupTerminalBlockData; + + if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockData)) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, (ULONG)groupTerminalBlockData.size(), (PVOID)groupTerminalBlockData.data() }); + } + else + { + // write empty data + } + + + // Fold in custom properties, including MIDI 1 port names and naming approach + // =============================================================================== + + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.UsbVendorId = masterEndpointDefinition->VID; + matchCriteria.UsbProductId = masterEndpointDefinition->PID; + matchCriteria.UsbSerialNumber = masterEndpointDefinition->SerialNumber; + matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + + auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); + + std::wstring customName{ }; + std::wstring customDescription{ }; + if (customProperties != nullptr) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), + TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") + ); + + // this includes image, the Midi 1 naming approach, etc. + customProperties->WriteNonCommonProperties(interfaceDevProperties); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + ); + } + + // store the property data for the name table + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + + // Write Name table property, folding in the custom names we discovered earlier + // =============================================================================================== + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + HRESULT updateResult{}; + + LOG_IF_FAILED(updateResult = m_midiDeviceManager->UpdateEndpointProperties( + masterEndpointDefinition->EndpointDeviceId.c_str(), + static_cast(interfaceDevProperties.size()), + interfaceDevProperties.data() + )); + + + if (SUCCEEDED(updateResult)) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint updated with new filter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) + ); + + auto lock = m_availableEndpointDefinitionsLock.lock(); + + // Add to internal endpoint manager + m_availableEndpointDefinitionsV2.insert_or_assign( + internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->ParentDeviceInstanceId), + masterEndpointDefinition); + + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint update failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingHResult(updateResult, MIDI_TRACE_EVENT_HRESULT_FIELD) + ); + } + + return updateResult; +} + + HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName) { @@ -817,7 +1449,29 @@ ParseParentIdIntoVidPidSerial( _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFilterDevice( +CMidi2KSAggregateMidiEndpointManager::FindActivatedMasterEndpointDefinitionForFilterDevice( + std::wstring parentDeviceInstanceId, + std::shared_ptr& endpointDefinition +) +{ + for (auto const& entry : m_availableEndpointDefinitionsV2) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == + internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) + { + endpointDefinition = entry.second; + + return S_OK; + } + } + + return E_NOTFOUND; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::FindOrCreatePendingMasterEndpointDefinitionForFilterDevice( DeviceInformation filterDevice, std::shared_ptr& endpointDefinition ) @@ -874,8 +1528,6 @@ CMidi2KSAggregateMidiEndpointManager::FindOrCreateMasterEndpointDefinitionForFil auto newEndpointDefinition = std::make_shared(); RETURN_HR_IF_NULL(E_OUTOFMEMORY, newEndpointDefinition); -// auto systemDevicesParent = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Parent", parentDevice, L""); - newEndpointDefinition->ParentDeviceName = parentDevice.Name(); newEndpointDefinition->EndpointName = parentDevice.Name(); newEndpointDefinition->ParentDeviceInstanceId = parentDevice.Id(); @@ -1037,7 +1689,7 @@ void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin()); // create the endpoint - LOG_IF_FAILED(CreateMidiUmpEndpoint(*ep)); + LOG_IF_FAILED(CreateMidiUmpEndpointV2(ep)); } TraceLoggingWrite( @@ -1096,9 +1748,9 @@ _Use_decl_annotations_ bool CMidi2KSAggregateMidiEndpointManager::KSAEndpointForDeviceExists( _In_ std::wstring parentDeviceInstanceId) { - for (auto const& entry : m_availableEndpointDefinitions) + for (auto const& entry : m_availableEndpointDefinitionsV2) { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second.ParentDeviceInstanceId) == + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) { return true; @@ -1273,7 +1925,58 @@ CMidi2KSAggregateMidiEndpointManager::GetMidi1FilterPins( } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::UpdateNewPinDefinitions( + std::wstring filterDeviceid, + std::wstring driverSuppliedName, + std::shared_ptr endpointDefinition) +{ + // At this point, we need to have *all* the pins for the endpoint, not just this filter + for (auto& pinDefinition : endpointDefinition->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != + internal::NormalizeDeviceInstanceIdWStringCopy(filterDeviceid)) + { + // only process the pins for this filter interface. We don't want to + // change anything that has already been built. But we do need the + // context of all pins when getting the group index. + continue; + } + + // Figure out the group index for the pin. This needs the context of the + // entire device. Failure to get the group index is fatal + RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceid.c_str(), "filter device id"), + TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), + TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == + MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") + ); + + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update + + // Build the name table entry for this individual pin + endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition.GroupIndex, + pinDefinition.DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition.FilterName, + pinDefinition.PinName, + pinDefinition.PortIndexWithinThisFilterAndDirection + ); + } + return S_OK; +} _Use_decl_annotations_ @@ -1322,10 +2025,21 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( return S_OK; } - - // check to see if we already have an activated endpoint for this filter auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); + + // Driver-supplied name. This is needed for WinMM backwards compatibility + std::wstring driverSuppliedName{}; + + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail, + // adding unnecessary noise to the error logs + deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); + + + // check to see if we already have an *activated* endpoint for this filter if (KSAEndpointForDeviceExists(parentInstanceId.c_str())) { TraceLoggingWrite( @@ -1339,9 +2053,19 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") ); - // TODO: We need to add this interface to the existing device + std::shared_ptr existingActivatedEndpointDefinition { nullptr }; + + // first MIDI 1 pin we're processing for this interface + RETURN_IF_FAILED(FindActivatedMasterEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); + + // add our new pins into the existing endpoint definition + existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), driverSuppliedName, existingActivatedEndpointDefinition)); + + RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); + return S_OK; - // END TODO ============================================================== } else { @@ -1366,7 +2090,7 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( if (endpointDefinition == nullptr) { // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreateMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_IF_FAILED(FindOrCreatePendingMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); // add our new pins into the existing endpoint definition @@ -1374,15 +2098,6 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( pinList.clear(); // just make sure we don't use this one, accidentally } - // Get the driver-supplied name. This is needed for WinMM backwards compatibility - std::wstring driverSuppliedName{}; - - // Using lamba function to prevent handle from dissapearing when being used. - // we don't log the HRESULT because this is not critical and will often fail, - // adding unnecessary noise to the error logs - deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - }); #ifdef _DEBUG if (!driverSuppliedName.empty()) @@ -1958,24 +2673,49 @@ winrt::hstring CMidi2KSAggregateMidiEndpointManager::FindMatchingInstantiatedEnd { criteria.Normalize(); - for (auto const& def : m_availableEndpointDefinitions) + if (NewMidiFeatureUpdateKsa2603Enabled()) { - WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; + for (auto const& def : m_availableEndpointDefinitionsV2) + { + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; - available.DeviceInstanceId = def.second.EndpointDeviceInstanceId; - available.EndpointDeviceId = def.second.EndpointDeviceId; - available.UsbVendorId = def.second.VID; - available.UsbProductId = def.second.PID; - available.UsbSerialNumber = def.second.SerialNumber; - available.TransportSuppliedEndpointName = def.second.EndpointName; - available.DeviceManufacturerName = def.second.ManufacturerName; + available.DeviceInstanceId = def.second->EndpointDeviceInstanceId; + available.EndpointDeviceId = def.second->EndpointDeviceId; + available.UsbVendorId = def.second->VID; + available.UsbProductId = def.second->PID; + available.UsbSerialNumber = def.second->SerialNumber; + available.TransportSuppliedEndpointName = def.second->EndpointName; + available.DeviceManufacturerName = def.second->ManufacturerName; - if (available.Matches(criteria)) + if (available.Matches(criteria)) + { + return available.EndpointDeviceId; + } + } + } + else + { + for (auto const& def : m_availableEndpointDefinitions) { - return available.EndpointDeviceId; + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; + + available.DeviceInstanceId = def.second.EndpointDeviceInstanceId; + available.EndpointDeviceId = def.second.EndpointDeviceId; + available.UsbVendorId = def.second.VID; + available.UsbProductId = def.second.PID; + available.UsbSerialNumber = def.second.SerialNumber; + available.TransportSuppliedEndpointName = def.second.EndpointName; + available.DeviceManufacturerName = def.second.ManufacturerName; + + if (available.Matches(criteria)) + { + return available.EndpointDeviceId; + } } } + + return L""; } diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index 844acfa0a..289c05e98 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -78,7 +78,8 @@ class CMidi2KSAggregateMidiEndpointManager : private: STDMETHOD(CreateMidiUmpEndpoint)(_In_ KsAggregateEndpointDefinition& masterEndpointDefinition); - + STDMETHOD(CreateMidiUmpEndpointV2)(_In_ std::shared_ptr masterEndpointDefinition); + HRESULT OnDeviceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); HRESULT OnDeviceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); HRESULT OnDeviceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); @@ -96,11 +97,16 @@ class CMidi2KSAggregateMidiEndpointManager : wil::critical_section m_availableEndpointDefinitionsLock; std::map m_availableEndpointDefinitions; - + std::map> m_availableEndpointDefinitionsV2; // for 2603 CFR update only + wil::critical_section m_pendingEndpointDefinitionsLock; std::vector> m_pendingEndpointDefinitions; - HRESULT FindOrCreateMasterEndpointDefinitionForFilterDevice( + HRESULT FindActivatedMasterEndpointDefinitionForFilterDevice( + _In_ std::wstring parentDeviceInstanceId, + _In_ std::shared_ptr&); + + HRESULT FindOrCreatePendingMasterEndpointDefinitionForFilterDevice( _In_ DeviceInformation, _In_ std::shared_ptr&); @@ -116,10 +122,28 @@ class CMidi2KSAggregateMidiEndpointManager : _In_ MidiFlow dataFlowFromUserPerspective, _In_ uint8_t& groupIndex); + HRESULT UpdateNewPinDefinitions( + _In_ std::wstring filterDeviceid, + _In_ std::wstring driverSuppliedName, + _In_ std::shared_ptr endpointDefinition); + + HRESULT BuildPinsAndGroupTerminalBlocksPropertyData( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::vector& pinMapPropertyData, + _In_ std::vector& groupTerminalBlocks); + + HRESULT UpdateNameTableWithCustomProperties( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr customProperties); + wil::unique_event_nothrow m_endpointCreationThreadWakeup; std::jthread m_endpointCreationThread; void EndpointCreationThreadWorker(_In_ std::stop_token token); + HRESULT UpdateExistingMidiUmpEndpointWithFilterChanges( + _In_ std::shared_ptr masterEndpointDefinition); + + DeviceWatcher m_watcher{0}; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; From 892b20529e2e39a0bf1a1e8af353d3e8c342f57b Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sun, 8 Feb 2026 19:56:02 -0500 Subject: [PATCH 08/18] Add support for removing loopMIDI ports --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 85 ++++++++++++++++--- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 9f70d9219..0e11a39c1 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -2048,7 +2048,7 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"KSA endpoint for this filter already activated. TEMP skipping.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(L"KSA endpoint for this filter already activated. Updating it.", MIDI_TRACE_EVENT_MESSAGE_FIELD), TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") ); @@ -2179,22 +2179,81 @@ CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceRemoved( TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "added interface") + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "removed interface") ); - // Flow for interface REMOVED - // - Check for the interface on the existing device - // - Update pin map to remove all entries with that interface - // - If this is the last interface, then remove the device. - // - If not the last interface: - // - Reset the UPDATE_TIMEOUT timeout for this device. If no other removals come through for this device during the timeout: - // - Rebuild pin map, maintaining existing numbers where possible - // - Rebuild GTBs, maintaining existing numbers where possible - // - Recalculate name table - // - call MidiDeviceManager::UpdateEndpointProperties. That will also recalculate MIDI 1 ports - // + std::wstring removedFilterDeviceId{ internal::NormalizeDeviceInstanceIdWStringCopy(deviceInterfaceUpdate.Id().c_str()) }; + + // find an active device with this filter + + std::shared_ptr endpointDefinition{ nullptr }; + + + + for (auto& endpointListIterator : m_availableEndpointDefinitionsV2) + { + // check pins for this filter + for (auto& pin: endpointListIterator.second->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId) + { + endpointDefinition = endpointListIterator.second; + break; + } + } + + } + + if (endpointDefinition != nullptr) + { + bool done { false }; + + while (!done) + { + auto foundIt = std::find_if( + endpointDefinition->MidiPins.begin(), + endpointDefinition->MidiPins.end(), + [&removedFilterDeviceId](KsAggregateEndpointMidiPinDefinition& pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId; } + ); + + if (foundIt != endpointDefinition->MidiPins.end()) + { + // erase the pin definition with this + endpointDefinition->MidiPins.erase(foundIt); + } + else + { + // we've removed all the pins for this interface + done = true; + } + } + + if (endpointDefinition->MidiPins.size() > 0) + { + // we've removed all the pins for this interface, but there are still + // pins left, so now it's time to update the endpoint + + + // TODO: Need to cache the name from the driver/registry so we don't have to do a lookup here. + // update remaining pins in existing endpoint definition + RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, L"", endpointDefinition)); + RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); + } + else + { + auto lock = m_availableEndpointDefinitionsLock.lock(); + // notify the device manager using the InstanceId for this midi device + RETURN_IF_FAILED(m_midiDeviceManager->RemoveEndpoint( + internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId).c_str())); + + // remove the endpoint from the list + + m_availableEndpointDefinitionsV2.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId)); + } + } return S_OK; } From ab7bad196eedd35d93eec24cb1f6a7b452041f30 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Tue, 10 Feb 2026 10:56:03 -0500 Subject: [PATCH 09/18] Working on > 16 ports limitation --- .../Midi2.KSAggregateMidiEndpointManager.cpp | 8 +- .../Midi2.KSAggregateMidiEndpointManager.h | 76 +++++++++++++++---- 2 files changed, 64 insertions(+), 20 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 0e11a39c1..599ba4ec1 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -1449,9 +1449,9 @@ ParseParentIdIntoVidPidSerial( _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::FindActivatedMasterEndpointDefinitionForFilterDevice( +CMidi2KSAggregateMidiEndpointManager::FindActivatedEndpointDefinitionForFilterDevice( std::wstring parentDeviceInstanceId, - std::shared_ptr& endpointDefinition + std::shared_ptr& endpointDefinition ) { for (auto const& entry : m_availableEndpointDefinitionsV2) @@ -1471,9 +1471,9 @@ CMidi2KSAggregateMidiEndpointManager::FindActivatedMasterEndpointDefinitionForFi _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::FindOrCreatePendingMasterEndpointDefinitionForFilterDevice( +CMidi2KSAggregateMidiEndpointManager::FindOrCreatePendingEndpointDefinitionForFilterDevice( DeviceInformation filterDevice, - std::shared_ptr& endpointDefinition + std::shared_ptr& endpointDefinition ) { TraceLoggingWrite( diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index 289c05e98..2e5bbdfdd 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -56,13 +56,52 @@ struct KsAggregateEndpointDefinition std::vector MidiPins{ }; WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{}; +}; + + + +// new structures because we need to be able to pull together +// virtual endpoints, which have greater than 16 ins and/or outs +// and so need the creation of multiple endpoints. Without the +// new 2603 approach, only 16 in and out ports are available +// per parent device (teVirtualMidi in this case). Also impacts +// loopBE30. + +struct KsAggregateEndpointDefinitionV2 +{ + std::wstring EndpointDeviceId{}; + + std::wstring EndpointName{}; + std::wstring EndpointDeviceInstanceId{}; + + std::vector MidiPins{ }; + + WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{}; - // new for 2603 fix, but does not impact the existing code int8_t CurrentHighestMidiSourceGroupIndex{ -1 }; int8_t CurrentHighestMidiDestinationGroupIndex{ -1 }; }; +struct KsAggregateParentDeviceDefinitionV2 +{ + std::wstring DeviceName{}; + std::wstring DeviceInstanceId{}; + std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming + + uint16_t VID{ 0 }; // USB-only + uint16_t PID{ 0 }; // USB-only + std::wstring SerialNumber{}; + + std::wstring ManufacturerName{}; + + std::vector Endpoints{ }; // most devices will have just one endpoint, but virtual can have > 1 +}; + + + + + class CMidi2KSAggregateMidiEndpointManager : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, @@ -86,29 +125,33 @@ class CMidi2KSAggregateMidiEndpointManager : HRESULT OnDeviceStopped(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); - // new interface-based approach - HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); - HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); - HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); - - wil::com_ptr_nothrow m_midiDeviceManager; wil::com_ptr_nothrow m_midiProtocolManager; wil::critical_section m_availableEndpointDefinitionsLock; std::map m_availableEndpointDefinitions; - std::map> m_availableEndpointDefinitionsV2; // for 2603 CFR update only wil::critical_section m_pendingEndpointDefinitionsLock; std::vector> m_pendingEndpointDefinitions; - HRESULT FindActivatedMasterEndpointDefinitionForFilterDevice( + + + // new interface-based approachfor 2603 CFR update + HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); + HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); + HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); + + std::map> m_availableEndpointDefinitionsV2; + std::vector> m_pendingEndpointDefinitionsV2; + + + HRESULT FindActivatedEndpointDefinitionForFilterDevice( _In_ std::wstring parentDeviceInstanceId, - _In_ std::shared_ptr&); + _In_ std::shared_ptr&); - HRESULT FindOrCreatePendingMasterEndpointDefinitionForFilterDevice( + HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( _In_ DeviceInformation, - _In_ std::shared_ptr&); + _In_ std::shared_ptr&); HRESULT GetMidi1FilterPins( _In_ DeviceInformation, @@ -118,22 +161,22 @@ class CMidi2KSAggregateMidiEndpointManager : _In_ std::wstring deviceInstanceId); HRESULT IncrementAndGetNextGroupIndex( - _In_ std::shared_ptr definition, + _In_ std::shared_ptr definition, _In_ MidiFlow dataFlowFromUserPerspective, _In_ uint8_t& groupIndex); HRESULT UpdateNewPinDefinitions( _In_ std::wstring filterDeviceid, _In_ std::wstring driverSuppliedName, - _In_ std::shared_ptr endpointDefinition); + _In_ std::shared_ptr endpointDefinition); HRESULT BuildPinsAndGroupTerminalBlocksPropertyData( - _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr masterEndpointDefinition, _In_ std::vector& pinMapPropertyData, _In_ std::vector& groupTerminalBlocks); HRESULT UpdateNameTableWithCustomProperties( - _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr masterEndpointDefinition, _In_ std::shared_ptr customProperties); wil::unique_event_nothrow m_endpointCreationThreadWakeup; @@ -145,6 +188,7 @@ class CMidi2KSAggregateMidiEndpointManager : + DeviceWatcher m_watcher{0}; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Removed_revoker m_DeviceRemoved; From b6a5b713a1e21e79ee0b580766d96deab9aeb18f Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Tue, 10 Feb 2026 18:06:33 -0500 Subject: [PATCH 10/18] Change feature staging to match internal --- ...ure_Servicing_MIDI2VirtualPortDriversFix.h | 18 ++++++++++++ .../Midi2.KSAggregateMidi.cpp | 5 ++-- .../Midi2.KSAggregateMidiEndpointManager.cpp | 29 +++++++++++++++---- src/api/Transport/KSAggregateTransport/pch.h | 9 ------ 4 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h diff --git a/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h b/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h new file mode 100644 index 000000000..4c166d7f9 --- /dev/null +++ b/src/api/Inc/Feature_Servicing_MIDI2VirtualPortDriversFix.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License +// ============================================================================ +// This is part of the Windows MIDI Services App API and should be used +// in your Windows application via an official binary distribution. +// Further information: https://aka.ms/midi +// ============================================================================ + +#pragma once + +class Feature_Servicing_MIDI2VirtualPortDriversFix +{ +public: + static bool IsEnabled() + { + return true; + } +}; \ No newline at end of file diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp index 372982727..a383d1263 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp @@ -11,6 +11,7 @@ #include "ump_iterator.h" +#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h" _Use_decl_annotations_ HRESULT @@ -207,7 +208,7 @@ CMidi2KSAggregateMidi::Initialize( // needed for internal consumption. Gary to replace this with feature enablement check // defined in pch.h - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { auto initResult = proxy->Initialize( @@ -293,7 +294,7 @@ CMidi2KSAggregateMidi::Initialize( // needed for internal consumption. Gary to replace this with feature enablement check // defined in pch.h - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { auto initResult = proxy->Initialize( diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 599ba4ec1..00ea3207d 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -13,6 +13,9 @@ #include // for the string stream in parsing of VID/PID/Serial from parent id #include // for getline for string parsing of VID/PID/Serial from parent id +#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h" + + using namespace wil; using namespace winrt::Windows::Devices::Enumeration; using namespace winrt::Windows::Foundation; @@ -46,7 +49,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( // needed for internal consumption. Gary to replace this with feature enablement check // defined in pch.h - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { DWORD individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; if (SUCCEEDED(wil::reg::get_value_dword_nothrow(HKEY_LOCAL_MACHINE, MIDI_ROOT_REG_KEY, KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE, &individualInterfaceEnumTimeoutMS))) @@ -126,7 +129,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( // Wait for everything to be created so that they're available immediately after service start. m_EnumerationCompleted.wait(INITIAL_ENUMERATION_TIMEOUT_MS); - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { if (m_pendingEndpointDefinitions.size() > 0) { @@ -1310,7 +1313,7 @@ CMidi2KSAggregateMidiEndpointManager::GetKSDriverSuppliedName(HANDLE hInstantiat &countBytesReturned ); - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { // changed to not log the failure here. Failures are expected for many devices, and it's adding noise to error logs if (FAILED(hrComponent)) @@ -1979,6 +1982,22 @@ CMidi2KSAggregateMidiEndpointManager::UpdateNewPinDefinitions( } + + +HRESULT +PopulatePinKSDataFormats(HANDLE filterHandle, Some_vector_of_pin_format_structs) +{ + //Try this, it should be a fairly easy thing to add to your change. + // retrieve the : + //KSPROPSETID_Pin, + // KSPROPERTY_PIN_DATARANGES, + + // limit to pins with(pKsDataFormat->MajorFormat == KSDATAFORMAT_TYPE_MUSIC) + // + // Retrieval is going to follow the same ksmultipleitemp pattern as KSPROPERTY_MIDI2_GROUP_TERMINAL_BLOCKS +} + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( @@ -2732,7 +2751,7 @@ winrt::hstring CMidi2KSAggregateMidiEndpointManager::FindMatchingInstantiatedEnd { criteria.Normalize(); - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { for (auto const& def : m_availableEndpointDefinitionsV2) { @@ -2790,7 +2809,7 @@ CMidi2KSAggregateMidiEndpointManager::Shutdown() TraceLoggingPointer(this, "this") ); - if (NewMidiFeatureUpdateKsa2603Enabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { m_endpointCreationThread.request_stop(); m_endpointCreationThreadWakeup.SetEvent(); diff --git a/src/api/Transport/KSAggregateTransport/pch.h b/src/api/Transport/KSAggregateTransport/pch.h index f34d0fb1e..7aa72d432 100644 --- a/src/api/Transport/KSAggregateTransport/pch.h +++ b/src/api/Transport/KSAggregateTransport/pch.h @@ -113,15 +113,6 @@ namespace internal = ::WindowsMidiServicesInternal; #include "Midi2UMP2BSTransform.h" #include "Midi2UMP2BSTransform_i.c" - -// this gets replaced with the internal CFR check when pulled into Windows repo -inline bool NewMidiFeatureUpdateKsa2603Enabled() { return true; } - - - - - - class CMidi2KSAggregateMidiEndpointManager; class CMidi2KSAggregateMidiInProxy; class CMidi2KSAggregateMidiOutProxy; From 5e03ca40e2cd908ff7f879a8b30ee6cfa9a77fec Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Sat, 14 Feb 2026 20:57:52 -0500 Subject: [PATCH 11/18] Working on KSA updates --- .../Midi2.KSAggregateMidi.cpp | 2 - .../Midi2.KSAggregateMidiEndpointManager.cpp | 1926 ++-------------- .../Midi2.KSAggregateMidiEndpointManager.h | 110 +- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 1962 +++++++++++++++++ .../Midi2.KSAggregateMidiEndpointManager2.h | 179 ++ .../Midi2.KSAggregateTransport.cpp | 14 +- .../Midi2.KSAggregateTransport.vcxproj | 6 +- ...Midi2.KSAggregateTransport.vcxproj.filters | 10 +- .../KSAggregateTransport/TransportState.cpp | 12 +- .../KSAggregateTransport/TransportState.h | 26 +- src/api/Transport/KSAggregateTransport/pch.h | 4 + 11 files changed, 2370 insertions(+), 1881 deletions(-) create mode 100644 src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp create mode 100644 src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp index a383d1263..219425de8 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidi.cpp @@ -11,8 +11,6 @@ #include "ump_iterator.h" -#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h" - _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidi::Initialize( diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp index 00ea3207d..b2a19b9fb 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.cpp @@ -13,9 +13,6 @@ #include // for the string stream in parsing of VID/PID/Serial from parent id #include // for getline for string parsing of VID/PID/Serial from parent id -#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h" - - using namespace wil; using namespace winrt::Windows::Devices::Enumeration; using namespace winrt::Windows::Foundation; @@ -47,104 +44,37 @@ CMidi2KSAggregateMidiEndpointManager::Initialize( RETURN_IF_FAILED(midiDeviceManager->QueryInterface(__uuidof(IMidiDeviceManager), (void**)&m_midiDeviceManager)); RETURN_IF_FAILED(midiEndpointProtocolManager->QueryInterface(__uuidof(IMidiEndpointProtocolManager), (void**)&m_midiProtocolManager)); - // needed for internal consumption. Gary to replace this with feature enablement check - // defined in pch.h - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) - { - DWORD individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; - if (SUCCEEDED(wil::reg::get_value_dword_nothrow(HKEY_LOCAL_MACHINE, MIDI_ROOT_REG_KEY, KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE, &individualInterfaceEnumTimeoutMS))) - { - individualInterfaceEnumTimeoutMS = max(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE); - individualInterfaceEnumTimeoutMS = min(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE); - - m_individualInterfaceEnumTimeoutMS = individualInterfaceEnumTimeoutMS; - } - else - { - m_individualInterfaceEnumTimeoutMS = DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS; - } - - - // the ksa2603 fix enumerates device interfaces instead of parent devices - - winrt::hstring deviceInterfaceSelector( - L"System.Devices.InterfaceClassGuid:=\"{6994AD04-93EF-11D0-A3CC-00A0C9223196}\" AND " \ - L"System.Devices.InterfaceEnabled: = System.StructuredQueryType.Boolean#True"); - - auto additionalProps = winrt::single_threaded_vector(); - additionalProps.Append(L"System.Devices.Parent"); - - m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector); - - auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded); - auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceRemoved); - auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceUpdated); - - auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); - auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); + winrt::hstring parentDeviceSelector( + L"System.Devices.ClassGuid:=\"{4d36e96c-e325-11ce-bfc1-08002be10318}\" AND " \ + L"System.Devices.Present:=System.StructuredQueryType.Boolean#True"); - m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); - m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); - m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); - m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); - m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); + // :=System.StructuredQueryType.Boolean#True - // worker thread to handle endpoint creation, since we're enumerating interfaces now and need to aggregate them - m_endpointCreationThreadWakeup.create(wil::EventOptions::ManualReset); - std::jthread endpointCreationWorkerThread(std::bind_front(&CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker, this)); - m_endpointCreationThread = std::move(endpointCreationWorkerThread); - } - else - { - winrt::hstring parentDeviceSelector( - L"System.Devices.ClassGuid:=\"{4d36e96c-e325-11ce-bfc1-08002be10318}\" AND " \ - L"System.Devices.Present:=System.StructuredQueryType.Boolean#True"); - - // :=System.StructuredQueryType.Boolean#True + auto additionalProps = winrt::single_threaded_vector(); - auto additionalProps = winrt::single_threaded_vector(); + additionalProps.Append(L"System.Devices.DeviceManufacturer"); + additionalProps.Append(L"System.Devices.Manufacturer"); + additionalProps.Append(L"System.Devices.Parent"); - additionalProps.Append(L"System.Devices.DeviceManufacturer"); - additionalProps.Append(L"System.Devices.Manufacturer"); - additionalProps.Append(L"System.Devices.Parent"); + m_watcher = DeviceInformation::CreateWatcher(parentDeviceSelector, additionalProps, DeviceInformationKind::Device); - m_watcher = DeviceInformation::CreateWatcher(parentDeviceSelector, additionalProps, DeviceInformationKind::Device); - - auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded); - auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved); - auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceUpdated); - auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); - auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); - - m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); - m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); - m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); - m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); - m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); - } + auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded); + auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved); + auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceUpdated); + auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnDeviceStopped); + auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager::OnEnumerationCompleted); + m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); + m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); + m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); + m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); + m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); m_watcher.Start(); // Wait for everything to be created so that they're available immediately after service start. m_EnumerationCompleted.wait(INITIAL_ENUMERATION_TIMEOUT_MS); - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) - { - if (m_pendingEndpointDefinitions.size() > 0) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enumeration completed with endpoint definitions left pending.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingUInt32(static_cast(m_pendingEndpointDefinitions.size()), "count pending definitions") - ); - } - } - return S_OK; } @@ -598,1715 +528,213 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint( } -_Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::UpdateNameTableWithCustomProperties( - std::shared_ptr masterEndpointDefinition, - std::shared_ptr customProperties) +GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName) { - RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + std::unique_ptr pinNameData; + ULONG pinNameDataSize{ 0 }; + + auto pinNameHR = PinPropertyAllocate( + hFilter, + pinIndex, + KSPROPSETID_Pin, + KSPROPERTY_PIN_NAME, + (PVOID*)&pinNameData, + &pinNameDataSize + ); - for (auto const& pinEntry : masterEndpointDefinition->MidiPins) + if (SUCCEEDED(pinNameHR) || pinNameHR == HRESULT_FROM_WIN32(ERROR_SET_NOT_FOUND)) { - if (customProperties != nullptr && - (customProperties->Midi1Destinations.size() > 0 || customProperties->Midi1Sources.size() > 0)) + // Check to see if the pin has an iJack name + if (pinNameDataSize > 0) { - if (pinEntry.PinDataFlow == MidiFlow::MidiFlowIn) - { - // message destination (output port), pin flow is In - if (auto customConfiguredName = customProperties->Midi1Destinations.find(pinEntry.GroupIndex); - customConfiguredName != customProperties->Midi1Destinations.end()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found custom name for a Midi 1 destination.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), - TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), - TraceLoggingUInt8(pinEntry.GroupIndex, "group index") - ); - - masterEndpointDefinition->EndpointNameTable.UpdateDestinationEntryCustomName(pinEntry.GroupIndex, customConfiguredName->second.Name); - } - } - else if (pinEntry.PinDataFlow == MidiFlow::MidiFlowOut) - { - // message source (input port), pin flow is Out - if (auto customConfiguredName = customProperties->Midi1Sources.find(pinEntry.GroupIndex); - customConfiguredName != customProperties->Midi1Sources.end()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found custom name for a Midi 1 source.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), - TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), - TraceLoggingUInt8(pinEntry.GroupIndex, "group index") - ); - - masterEndpointDefinition->EndpointNameTable.UpdateSourceEntryCustomName(pinEntry.GroupIndex, customConfiguredName->second.Name); - } - } + pinName = pinNameData.get(); + + return S_OK; } - } - - - return S_OK; + return E_FAIL; } - -_Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::BuildPinsAndGroupTerminalBlocksPropertyData( - std::shared_ptr masterEndpointDefinition, - std::vector& pinMapPropertyData, - std::vector& groupTerminalBlocks) +GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow) { - uint8_t currentBlockNumber{ 0 }; - std::vector pinMapEntries{ }; + auto dataFlowHR = PinPropertySimple( + hFilter, + pinIndex, + KSPROPSETID_Pin, + KSPROPERTY_PIN_DATAFLOW, + &dataFlow, + sizeof(KSPIN_DATAFLOW) + ); - for (auto const& pin : masterEndpointDefinition->MidiPins) + if (SUCCEEDED(dataFlowHR)) { - RETURN_HR_IF(E_INVALIDARG, pin.FilterDeviceId.empty()); - - internal::GroupTerminalBlockInternal gtb; - - gtb.Number = ++currentBlockNumber; - gtb.GroupCount = 1; // always a single group for aggregate MIDI 1.0 devices - - PinMapEntryStagingEntry pinMapEntry{ }; - - pinMapEntry.PinId = pin.PinNumber; - pinMapEntry.FilterId = pin.FilterDeviceId; - pinMapEntry.PinDataFlow = pin.PinDataFlow; - - //MidiFlow flowFromUserPerspective; - - pinMapEntry.GroupIndex = pin.GroupIndex; - gtb.FirstGroupIndex = pin.GroupIndex; - - if (pin.PinDataFlow == MidiFlow::MidiFlowIn) // pin in, so user out : A MIDI Destination - { - gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_INPUT; // from the pin/gtb's perspective - - auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetDestinationEntry(gtb.FirstGroupIndex); - if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) - { - gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); - } - } - else if (pin.PinDataFlow == MidiFlow::MidiFlowOut) // pin out, so user in : A MIDI Source - { - gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_OUTPUT; // from the pin/gtb's perspective - auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetSourceEntry(gtb.FirstGroupIndex); - if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) - { - gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); - } - } - else - { - RETURN_IF_FAILED(E_INVALIDARG); - } - - // name fallback - if (gtb.Name.empty()) - { - gtb.Name = masterEndpointDefinition->EndpointName; - - if (gtb.FirstGroupIndex > 0) - { - gtb.Name += L" " + std::wstring{ gtb.FirstGroupIndex }; - } - } - - // default values as defined in the MIDI 2.0 USB spec - gtb.Protocol = 0x01; // midi 1.0 - gtb.MaxInputBandwidth = 0x0001; // 31.25 kbps - gtb.MaxOutputBandwidth = 0x0001; // 31.25 kbps - - groupTerminalBlocks.push_back(gtb); - pinMapEntries.push_back(pinMapEntry); + return S_OK; } - // Write Pin Map Property - // ===================================================== + return E_FAIL; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager::GetKSDriverSuppliedName(HANDLE hInstantiatedFilter, std::wstring& name) +{ TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_INFO, TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Building pin map property", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD) ); - // build the pin map property value - KSAGGMIDI_PIN_MAP_PROPERTY_VALUE pinMap{ }; + // get the name GUID - size_t totalStringSizesIncludingNulls{ 0 }; - for (auto const& entry : pinMapEntries) - { - totalStringSizesIncludingNulls += ((entry.FilterId.length() + 1) * sizeof(wchar_t)); - } + KSCOMPONENTID componentId{}; + KSPROPERTY prop{}; + ULONG countBytesReturned{}; - size_t totalMemoryBytes{ - SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER + - SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING * pinMapEntries.size() + - totalStringSizesIncludingNulls }; + prop.Id = KSPROPERTY_GENERAL_COMPONENTID; + prop.Set = KSPROPSETID_General; + prop.Flags = KSPROPERTY_TYPE_GET; - pinMapPropertyData.resize(totalMemoryBytes); - auto currentPos = pinMapPropertyData.data(); + auto hrComponent = SyncIoctl( + hInstantiatedFilter, + IOCTL_KS_PROPERTY, + &prop, + sizeof(KSPROPERTY), + &componentId, + sizeof(KSCOMPONENTID), + &countBytesReturned + ); - // header - auto pinMapHeader = (PKSAGGMIDI_PIN_MAP_PROPERTY_VALUE)currentPos; - pinMapHeader->TotalByteCount = (UINT32)totalMemoryBytes; - currentPos += SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER; + RETURN_IF_FAILED(hrComponent); - for (auto const& entry : pinMapEntries) + componentId.Name; // this is the GUID which points to the registry location with the driver-supplied name + + if (componentId.Name != GUID_NULL) { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Processing Pin Map entry", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), - TraceLoggingUInt32(entry.PinId, "Pin Id"), - TraceLoggingWideString(entry.FilterId.c_str(), "Filter Id") - ); + // we have the GUID where this name is stored, so get the driver-supplied name from the registry - PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY propEntry = (PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY)currentPos; + WCHAR nameFromRegistry[MAX_PATH]{ 0 }; // this should only be MAXPNAMELEN, but if someone tampered with it, could be larger, hence MAX_PATH - propEntry->ByteCount = (UINT)(SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING + ((entry.FilterId.length() + 1) * sizeof(wchar_t))); - propEntry->GroupIndex = entry.GroupIndex; - propEntry->PinDataFlow = entry.PinDataFlow; - propEntry->PinId = entry.PinId; + std::wstring regKey = L"SYSTEM\\CurrentControlSet\\Control\\MediaCategories\\" + internal::GuidToString(componentId.Name); - if (!entry.FilterId.empty()) + if (SUCCEEDED(wil::reg::get_value_string_nothrow(HKEY_LOCAL_MACHINE, regKey.c_str(), L"Name", nameFromRegistry))) { - wcscpy_s((wchar_t*)propEntry->FilterId, entry.FilterId.length() + 1, entry.FilterId.c_str()); + name = nameFromRegistry; } - currentPos += propEntry->ByteCount; + return S_OK; } - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"All pin map entries copied to property memory", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") - ); + return E_NOTFOUND; +} - return S_OK; -} + +#define KS_CATEGORY_AUDIO_GUID L"{6994AD04-93EF-11D0-A3CC-00A0C9223196}" -_Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpointV2( - std::shared_ptr masterEndpointDefinition -) +ParseParentIdIntoVidPidSerial( + _In_ winrt::hstring systemDevicesParentValue, + _In_ KsAggregateEndpointDefinition& endpointDefinition) { - RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") - ); - - DEVPROP_BOOLEAN devPropTrue = DEVPROP_TRUE; - - // we require at least one valid pin - RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); - - std::vector interfaceDevProperties; - - MIDIENDPOINTCOMMONPROPERTIES commonProperties{}; - commonProperties.TransportId = TRANSPORT_LAYER_GUID; - commonProperties.EndpointDeviceType = MidiEndpointDeviceType_Normal; - commonProperties.FriendlyName = masterEndpointDefinition->EndpointName.c_str(); - commonProperties.TransportCode = TRANSPORT_CODE; - commonProperties.EndpointName = masterEndpointDefinition->EndpointName.c_str(); - commonProperties.EndpointDescription = nullptr; - commonProperties.CustomEndpointName = nullptr; - commonProperties.CustomEndpointDescription = nullptr; - commonProperties.UniqueIdentifier = masterEndpointDefinition->SerialNumber.empty() ? nullptr : masterEndpointDefinition->SerialNumber.c_str(); - commonProperties.ManufacturerName = masterEndpointDefinition->ManufacturerName.empty() ? nullptr : masterEndpointDefinition->ManufacturerName.c_str(); - commonProperties.SupportedDataFormats = MidiDataFormats::MidiDataFormats_UMP; - commonProperties.NativeDataFormat = MidiDataFormats_ByteStream; - - UINT32 capabilities{ 0 }; - capabilities |= MidiEndpointCapabilities_SupportsMultiClient; - capabilities |= MidiEndpointCapabilities_GenerateIncomingTimestamps; - capabilities |= MidiEndpointCapabilities_SupportsMidi1Protocol; - commonProperties.Capabilities = (MidiEndpointCapabilities)capabilities; - - - std::vector pinMapPropertyData; - std::vector groupTerminalBlocks{ }; - std::vector nameTablePropertyData; - - RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( - masterEndpointDefinition, - pinMapPropertyData, - groupTerminalBlocks)); - - - interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); - - - // Write Group Terminal Block Property - // ===================================================== - - std::vector groupTerminalBlockPropertyData{}; - - if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockPropertyData)) - { - interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_BINARY, static_cast(groupTerminalBlockPropertyData.size()), - (PVOID)groupTerminalBlockPropertyData.data() }); - } - else - { - // write empty data - } - - - // Fold in custom properties, including MIDI 1 port names and naming approach - // =============================================================================== - - WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; - matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); - matchCriteria.UsbVendorId = masterEndpointDefinition->VID; - matchCriteria.UsbProductId = masterEndpointDefinition->PID; - matchCriteria.UsbSerialNumber = masterEndpointDefinition->SerialNumber; - matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; - - auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); - - // rebuild the name table, using the custom properties if present - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); - - std::wstring customName{ }; - std::wstring customDescription{ }; - if (customProperties != nullptr) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), - TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), - TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") - ); - - if (!customProperties->Name.empty()) - { - customName = customProperties->Name; - commonProperties.CustomEndpointName = customName.c_str(); - } - - if (!customProperties->Description.empty()) - { - customDescription = customProperties->Description; - commonProperties.CustomEndpointDescription = customDescription.c_str(); - } - - // this includes image, the Midi 1 naming approach, etc. - customProperties->WriteNonCommonProperties(interfaceDevProperties); - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) - ); - } - - // Write Name table property, folding in the custom names we discovered earlier - // =============================================================================================== - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); - - - // Write USB VID/PID Data - // ===================================================== - - if (masterEndpointDefinition->VID > 0) - { - interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&masterEndpointDefinition->VID }); - } - else - { - interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_EMPTY, 0, nullptr }); - } - - if (masterEndpointDefinition->PID > 0) - { - interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&masterEndpointDefinition->PID }); - } - else - { - interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_EMPTY, 0, nullptr }); - } - - - // Despite being a MIDI 1 device, we present as a UMP endpoint, so we need to set - // this property so the service can create the MIDI 1 ports without waiting for - // function blocks/discovery to complete or timeout (which it never will) - interfaceDevProperties.push_back({ { PKEY_MIDI_EndpointDiscoveryProcessComplete, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_BOOLEAN, (ULONG)sizeof(devPropTrue), (PVOID)&devPropTrue }); - - SW_DEVICE_CREATE_INFO createInfo{ }; - - createInfo.cbSize = sizeof(createInfo); - createInfo.pszInstanceId = masterEndpointDefinition->EndpointDeviceInstanceId.c_str(); - createInfo.CapabilityFlags = SWDeviceCapabilitiesNone; - createInfo.pszDeviceDescription = masterEndpointDefinition->EndpointName.c_str(); - - // Call the device manager and finish the creation - - HRESULT swdCreationResult; - wil::unique_cotaskmem_string newDeviceInterfaceId; - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Activating endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") - ); - - // set to true if we only want to create UMP endpoints - bool umpOnly = false; - - LOG_IF_FAILED( - swdCreationResult = m_midiDeviceManager->ActivateEndpoint( - masterEndpointDefinition->ParentDeviceInstanceId.c_str(), - umpOnly, - MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint - &commonProperties, - (ULONG)interfaceDevProperties.size(), - (ULONG)0, - interfaceDevProperties.data(), - nullptr, - &createInfo, - &newDeviceInterfaceId) - ); - - if (SUCCEEDED(swdCreationResult)) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), - TraceLoggingWideString(newDeviceInterfaceId.get(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) - ); - - // return new device interface id - masterEndpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); - - auto lock = m_availableEndpointDefinitionsLock.lock(); - - // Add to internal endpoint manager - m_availableEndpointDefinitionsV2.insert_or_assign( - internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->ParentDeviceInstanceId), - masterEndpointDefinition); - - return swdCreationResult; - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_ERROR, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), - TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD) - ); - - return swdCreationResult; - } -} - - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::UpdateExistingMidiUmpEndpointWithFilterChanges( - std::shared_ptr masterEndpointDefinition -) -{ - RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") - ); - - // we require at least one valid pin - RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); - - std::vector interfaceDevProperties{ }; - - std::vector pinMapPropertyData; - std::vector groupTerminalBlocks{ }; - std::vector nameTablePropertyData; - - // update the pin map to have all the existing pins - // plus the new pins. Update Group Terminal Blocks at the same time. - RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( - masterEndpointDefinition, - pinMapPropertyData, - groupTerminalBlocks)); - - - // Write Pin Map Property - // ===================================================== - interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); - - - // Write Group Terminal Block Property - // ===================================================== - - std::vector groupTerminalBlockData; - - if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockData)) - { - interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, - DEVPROP_TYPE_BINARY, (ULONG)groupTerminalBlockData.size(), (PVOID)groupTerminalBlockData.data() }); - } - else - { - // write empty data - } - - - // Fold in custom properties, including MIDI 1 port names and naming approach - // =============================================================================== - - WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; - matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); - matchCriteria.UsbVendorId = masterEndpointDefinition->VID; - matchCriteria.UsbProductId = masterEndpointDefinition->PID; - matchCriteria.UsbSerialNumber = masterEndpointDefinition->SerialNumber; - matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; - - auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); - - std::wstring customName{ }; - std::wstring customDescription{ }; - if (customProperties != nullptr) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), - TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), - TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") - ); - - // this includes image, the Midi 1 naming approach, etc. - customProperties->WriteNonCommonProperties(interfaceDevProperties); - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) - ); - } - - // store the property data for the name table - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); - - - // Write Name table property, folding in the custom names we discovered earlier - // =============================================================================================== - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); - - HRESULT updateResult{}; - - LOG_IF_FAILED(updateResult = m_midiDeviceManager->UpdateEndpointProperties( - masterEndpointDefinition->EndpointDeviceId.c_str(), - static_cast(interfaceDevProperties.size()), - interfaceDevProperties.data() - )); - - - if (SUCCEEDED(updateResult)) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Aggregate UMP endpoint updated with new filter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) - ); - - auto lock = m_availableEndpointDefinitionsLock.lock(); - - // Add to internal endpoint manager - m_availableEndpointDefinitionsV2.insert_or_assign( - internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->ParentDeviceInstanceId), - masterEndpointDefinition); - - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_ERROR, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Aggregate UMP endpoint update failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), - TraceLoggingHResult(updateResult, MIDI_TRACE_EVENT_HRESULT_FIELD) - ); - } - - return updateResult; -} - - -HRESULT -GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName) -{ - std::unique_ptr pinNameData; - ULONG pinNameDataSize{ 0 }; - - auto pinNameHR = PinPropertyAllocate( - hFilter, - pinIndex, - KSPROPSETID_Pin, - KSPROPERTY_PIN_NAME, - (PVOID*)&pinNameData, - &pinNameDataSize - ); - - if (SUCCEEDED(pinNameHR) || pinNameHR == HRESULT_FROM_WIN32(ERROR_SET_NOT_FOUND)) - { - // Check to see if the pin has an iJack name - if (pinNameDataSize > 0) - { - pinName = pinNameData.get(); - - return S_OK; - } - } - - return E_FAIL; -} - -HRESULT -GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow) -{ - auto dataFlowHR = PinPropertySimple( - hFilter, - pinIndex, - KSPROPSETID_Pin, - KSPROPERTY_PIN_DATAFLOW, - &dataFlow, - sizeof(KSPIN_DATAFLOW) - ); - - if (SUCCEEDED(dataFlowHR)) - { - return S_OK; - } - - return E_FAIL; -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::GetKSDriverSuppliedName(HANDLE hInstantiatedFilter, std::wstring& name) -{ - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - // get the name GUID - - KSCOMPONENTID componentId{}; - KSPROPERTY prop{}; - ULONG countBytesReturned{}; - - prop.Id = KSPROPERTY_GENERAL_COMPONENTID; - prop.Set = KSPROPSETID_General; - prop.Flags = KSPROPERTY_TYPE_GET; - - auto hrComponent = SyncIoctl( - hInstantiatedFilter, - IOCTL_KS_PROPERTY, - &prop, - sizeof(KSPROPERTY), - &componentId, - sizeof(KSCOMPONENTID), - &countBytesReturned - ); - - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) - { - // changed to not log the failure here. Failures are expected for many devices, and it's adding noise to error logs - if (FAILED(hrComponent)) - { - return hrComponent; - } - } - else - { - RETURN_IF_FAILED(hrComponent); - } - - componentId.Name; // this is the GUID which points to the registry location with the driver-supplied name - - if (componentId.Name != GUID_NULL) - { - // we have the GUID where this name is stored, so get the driver-supplied name from the registry - - WCHAR nameFromRegistry[MAX_PATH]{ 0 }; // this should only be MAXPNAMELEN, but if someone tampered with it, could be larger, hence MAX_PATH - - std::wstring regKey = L"SYSTEM\\CurrentControlSet\\Control\\MediaCategories\\" + internal::GuidToString(componentId.Name); - - if (SUCCEEDED(wil::reg::get_value_string_nothrow(HKEY_LOCAL_MACHINE, regKey.c_str(), L"Name", nameFromRegistry))) - { - name = nameFromRegistry; - } - - return S_OK; - } - - return E_NOTFOUND; -} - - - -#define KS_CATEGORY_AUDIO_GUID L"{6994AD04-93EF-11D0-A3CC-00A0C9223196}" - - - -HRESULT -ParseParentIdIntoVidPidSerial( - _In_ winrt::hstring systemDevicesParentValue, - _In_ KsAggregateEndpointDefinition& endpointDefinition) -{ - - if (systemDevicesParentValue.empty()) - { - RETURN_IF_FAILED(E_INVALIDARG); - } - - - // Examples - // --------------------------------------------------------------------------- - // Parent values with serial: USB\VID_16C0&PID_05E4\ContinuuMini_SN024066 - // USB\VID_2573&PID_008A\no_serial_number (yes, this is the iSerialNumber in USB :/ - // 0x03 iSerialNumber "no serial number" - // USB\VID_12E6&PID_002C\251959d4f21fc283 - // - // Parent values without serial: USB\VID_F055&PID_0069\8&2858bbac&0&4 - // USB\VID_2662&PID_000D\8&24eb0394&0&4 - // USB\VID_05E3&PID_0610\7&2f028fc9&0&4 - // ROOT\MOTUBUS\0000 - - std::wstring parentVal = systemDevicesParentValue.c_str(); - - std::wstringstream ss(parentVal); - std::wstring usbSection{}; - - std::getline(ss, usbSection, static_cast('\\')); - - if (usbSection == L"USB") - { - // get the VID/PID section - - std::wstring vidPidSection{}; - - std::getline(ss, vidPidSection, static_cast('\\')); - - if (!vidPidSection.empty()) - { - std::wstring serialSection{}; - std::getline(ss, serialSection, static_cast('\\')); - - std::wstring vidPidString1{}; - std::wstring vidPidString2{}; - - std::wstringstream ssVidPid(vidPidSection); - std::getline(ssVidPid, vidPidString1, static_cast('&')); - std::getline(ssVidPid, vidPidString2, static_cast('&')); - - wchar_t* end{ nullptr }; - - // find the VID - if (vidPidString1.starts_with(L"VID_")) - { - endpointDefinition.VID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); - } - else if (vidPidString2.starts_with(L"VID_")) - { - endpointDefinition.VID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); - } - - // find the PID - if (vidPidString1.starts_with(L"PID_")) - { - endpointDefinition.PID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); - } - else if (vidPidString2.starts_with(L"PID_")) - { - endpointDefinition.PID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); - } - - // serial numbers with a & in them, are generated by our system - // it's possible a vendor may have a serial number with this in it, - // but in that case, we just ditch it. - if (serialSection.find_first_of('&') == serialSection.npos) - { - // Windows replaces spaces in the serial number with the underscore. - // yes, this will end up catching the few (if any) serials that do - // actually include an underscore. However, there are a bunch with spaces. - std::replace(serialSection.begin(), serialSection.end(), '_', ' '); - endpointDefinition.SerialNumber = serialSection; - } - } - } - else - { - // not a USB device, or otherwise uses a custom driver. We can't count - // on being able to parse the parent id. Example: MOTU has - // ROOT\MOTUBUS\0000 as the parent - } - - return S_OK; -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::FindActivatedEndpointDefinitionForFilterDevice( - std::wstring parentDeviceInstanceId, - std::shared_ptr& endpointDefinition -) -{ - for (auto const& entry : m_availableEndpointDefinitionsV2) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == - internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) - { - endpointDefinition = entry.second; - - return S_OK; - } - } - - return E_NOTFOUND; -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::FindOrCreatePendingEndpointDefinitionForFilterDevice( - DeviceInformation filterDevice, - std::shared_ptr& endpointDefinition -) -{ - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - // we require that the System.Devices.DeviceInstanceId property was requested for the passed-in filter device - auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); - RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); - - auto additionalProperties = winrt::single_threaded_vector(); - additionalProperties.Append(L"System.Devices.DeviceManufacturer"); - additionalProperties.Append(L"System.Devices.Manufacturer"); - additionalProperties.Append(L"System.Devices.Parent"); - - auto parentDevice = DeviceInformation::CreateFromIdAsync(deviceInstanceId, - additionalProperties, winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get(); - - // See if we already have a pending master endpoint definition for this parent device - - auto lock = m_pendingEndpointDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing - - auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); - auto it = std::find_if( - m_pendingEndpointDefinitions.begin(), - m_pendingEndpointDefinitions.end(), - [&parentInstanceIdToFind](const std::shared_ptr def){return internal::NormalizeDeviceInstanceIdWStringCopy(def->ParentDeviceInstanceId) == parentInstanceIdToFind; }); - - if (it != m_pendingEndpointDefinitions.end()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found existing aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(parentInstanceIdToFind.c_str(), "parent") - ); - - endpointDefinition = *it; - return S_OK; - } - - // We don't have one, create one and add, and get all the parent device information for it - // we still have the map locked, so keep this code fast - auto newEndpointDefinition = std::make_shared(); - RETURN_HR_IF_NULL(E_OUTOFMEMORY, newEndpointDefinition); - - newEndpointDefinition->ParentDeviceName = parentDevice.Name(); - newEndpointDefinition->EndpointName = parentDevice.Name(); - newEndpointDefinition->ParentDeviceInstanceId = parentDevice.Id(); - - LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newEndpointDefinition->ParentDeviceInstanceId.c_str(), *newEndpointDefinition)); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Creating new aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(newEndpointDefinition->ParentDeviceInstanceId.c_str(), "parent") - ); - - // only some vendor drivers provide an actual manufacturer - // and all the in-box drivers just provide the Generic USB Audio string - // TODO: Is "Generic USB Audio" a string that is localized? If so, this - // code will not have the intended effect outside of en-US - auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); - if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") - { - newEndpointDefinition->ManufacturerName = manufacturer; - } - - // default hash is the device id. - std::hash hasher; - std::wstring hash; - hash = std::to_wstring(hasher(newEndpointDefinition->ParentDeviceInstanceId)); - - newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - m_pendingEndpointDefinitions.push_back(newEndpointDefinition); - endpointDefinition = newEndpointDefinition; - - - return S_OK; -} - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::IncrementAndGetNextGroupIndex( - std::shared_ptr definition, - MidiFlow dataFlowFromUserPerspective, - uint8_t& groupIndex) -{ - if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn) - { - definition->CurrentHighestMidiSourceGroupIndex++; - groupIndex = definition->CurrentHighestMidiSourceGroupIndex; - } - else - { - definition->CurrentHighestMidiDestinationGroupIndex++; - groupIndex = definition->CurrentHighestMidiDestinationGroupIndex; - } - - return S_OK; -} - -#define MAX_THREAD_WORKER_WAIT_TIME_MS 20000 - -_Use_decl_annotations_ -void CMidi2KSAggregateMidiEndpointManager::EndpointCreationThreadWorker( - std::stop_token token) -{ - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - while (!token.stop_requested()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Waiting to be woken up.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - // wait to be woken up - m_endpointCreationThreadWakeup.wait(MAX_THREAD_WORKER_WAIT_TIME_MS); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: I'm awake now.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - if (!token.stop_requested()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Short nap before checking.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingUInt32(m_individualInterfaceEnumTimeoutMS, "nap period (ms)") - ); - - // we sleep for this timeout before we check to see if the thread is signaled. - // this gives time for an additional interface notification to cause the event - // to be reset - Sleep(m_individualInterfaceEnumTimeoutMS); - - // if we're still signaled, that means no other pnp notifications came in during the nap, or if they - // did, they completed within that nap period - if (m_endpointCreationThreadWakeup.is_signaled() && m_pendingEndpointDefinitions.size() > 0) - { - m_endpointCreationThreadWakeup.ResetEvent(); // it's a manual reset event - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Thread was signaled and pending definition count > 0. Proceed to processing queue.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - - // lock the definitions so we can process them - auto lock = m_pendingEndpointDefinitionsLock.lock(); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Processing pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - while (m_pendingEndpointDefinitions.size() > 0) - { - // effectively a queue, but because we have to iterate and search - // this in other functions, a vector is more appropriate - auto ep = m_pendingEndpointDefinitions[0]; - m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin()); - - // create the endpoint - LOG_IF_FAILED(CreateMidiUmpEndpointV2(ep)); - } - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - } - -#ifdef _DEBUG - else - { - if (m_pendingEndpointDefinitions.size() == 0) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but no work to do. Pending count == 0.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but thread is no longer signaled", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - } - } -#endif - } - } - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Exit", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - -} - -_Use_decl_annotations_ -bool CMidi2KSAggregateMidiEndpointManager::KSAEndpointForDeviceExists( - _In_ std::wstring parentDeviceInstanceId) -{ - for (auto const& entry : m_availableEndpointDefinitionsV2) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == - internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) - { - return true; - } - } - - return false; -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::GetMidi1FilterPins( - DeviceInformation filterDevice, - std::vector& pinListToAddTo -) -{ - // Wrapper opens the handle internally. - KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); - RETURN_IF_FAILED(deviceHandleWrapper.Open()); - - // ============================================================================================= - // Go through all the enumerated pins, looking for a MIDI 1.0 pin - - // enumerate all the pins for this filter - ULONG cPins{ 0 }; - - RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins)); - })); - - ULONG midiInputPinIndexForThisFilter{ 0 }; - ULONG midiOutputPinIndexForThisFilter{ 0 }; - - // process the pins for this filter. Not all will be MIDI pins - for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++) - { - // Check the communication capabilities of the pin so we can fail fast - KSPIN_COMMUNICATION communication = (KSPIN_COMMUNICATION)0; - - RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return PinPropertySimple(h, pinIndex, KSPROPSETID_Pin, KSPROPERTY_PIN_COMMUNICATION, &communication, sizeof(KSPIN_COMMUNICATION)); - })); - - // The external connector pin representing the physical connection - // has KSPIN_COMMUNICATION_NONE. We can only create on software IO pins which - // have a valid communication. Skip connector pins. - if (communication == KSPIN_COMMUNICATION_NONE) - { - continue; - } - - // Duplicate the handle to safely pass it to another component or store it. - wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle()); - RETURN_IF_NULL_ALLOC(handleDupe); - - // we try to open UMP only so we understand the device - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Checking for UMP pin. This will fallback error fail for non-UMP devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - - KsHandleWrapper m_PinHandleWrapperUmp(filterDevice.Id().c_str(), pinIndex, MidiTransport_CyclicUMP, handleDupe.get()); - if (SUCCEEDED(m_PinHandleWrapperUmp.Open())) - { - // this is a UMP pin. The KS transport will handle it, so we skip it here. - // In the future, we may want to bail on the first UMP pin we find. - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found UMP/MIDI2 pin. Skipping for this transport.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - - continue; - } - - - // try to open as a MIDI 1 bytestream pin - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Checking for MIDI 1 pin. This will fallback error fail for non-MIDI devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - - KsHandleWrapper pinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); - if (SUCCEEDED(pinHandleWrapperMidi1.Open())) - { - // this is a MIDI 1.0 byte format pin, so let's process it - KsAggregateEndpointMidiPinDefinition pinDefinition{ }; - - pinDefinition.PinNumber = pinIndex; - pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; - pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; - - // find the name of the pin (supplied by iJack, and if not available, generated by the stack) - std::wstring pinName{ }; - HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetPinName(h, pinIndex, pinName); - }); - - if (SUCCEEDED(pinNameHr)) - { - pinDefinition.PinName = pinName; - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Pin has name", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(pinDefinition.PinName.c_str(), "pin name") - ); - } - - // get the data flow so we know if this is a MIDI Input (Source) or a MIDI Output (Destination) - KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; - RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetPinDataFlow(h, pinIndex, dataFlow); - })); - - if (dataFlow == KSPIN_DATAFLOW_IN) - { - // MIDI Out (input to device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); - - midiOutputPinIndexForThisFilter++; - } - else if (dataFlow == KSPIN_DATAFLOW_OUT) - { - // MIDI In (output from device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); - - midiInputPinIndexForThisFilter++; - } - - pinListToAddTo.push_back(pinDefinition); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"MIDI 1.0 pin added", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - } - } - - - return S_OK; -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::UpdateNewPinDefinitions( - std::wstring filterDeviceid, - std::wstring driverSuppliedName, - std::shared_ptr endpointDefinition) -{ - // At this point, we need to have *all* the pins for the endpoint, not just this filter - for (auto& pinDefinition : endpointDefinition->MidiPins) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != - internal::NormalizeDeviceInstanceIdWStringCopy(filterDeviceid)) - { - // only process the pins for this filter interface. We don't want to - // change anything that has already been built. But we do need the - // context of all pins when getting the group index. - continue; - } - - // Figure out the group index for the pin. This needs the context of the - // entire device. Failure to get the group index is fatal - RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDeviceid.c_str(), "filter device id"), - TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), - TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == - MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") - ); - - std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update - - // Build the name table entry for this individual pin - endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( - pinDefinition.GroupIndex, - pinDefinition.DataFlowFromUserPerspective, - customName, - driverSuppliedName, - pinDefinition.FilterName, - pinDefinition.PinName, - pinDefinition.PortIndexWithinThisFilterAndDirection - ); - } - - return S_OK; -} - - - - -HRESULT -PopulatePinKSDataFormats(HANDLE filterHandle, Some_vector_of_pin_format_structs) -{ - //Try this, it should be a fairly easy thing to add to your change. - // retrieve the : - //KSPROPSETID_Pin, - // KSPROPERTY_PIN_DATARANGES, - - // limit to pins with(pKsDataFormat->MajorFormat == KSDATAFORMAT_TYPE_MUSIC) - // - // Retrieval is going to follow the same ksmultipleitemp pattern as KSPROPERTY_MIDI2_GROUP_TERMINAL_BLOCKS -} - - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceAdded( - DeviceWatcher /* watcher */, - DeviceInformation filterDevice -) -{ - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") - ); - - std::wstring transportCode(TRANSPORT_CODE); - - // Wrapper opens the handle internally. - KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); - RETURN_IF_FAILED(deviceHandleWrapper.Open()); - - std::shared_ptr endpointDefinition{ nullptr }; - - // get all the MIDI 1 pins. These are only partially processed because some things - // like group index and naming require the full context of all filters/pins on the - // parent device. - std::vector pinList{ }; - RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList)); - - if (pinList.size() == 0) + if (systemDevicesParentValue.empty()) { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No MIDI 1.0 pins for this filter device", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - - return S_OK; + RETURN_IF_FAILED(E_INVALIDARG); } - auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); - RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); - - // Driver-supplied name. This is needed for WinMM backwards compatibility - std::wstring driverSuppliedName{}; - - // Using lamba function to prevent handle from dissapearing when being used. - // we don't log the HRESULT because this is not critical and will often fail, - // adding unnecessary noise to the error logs - deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - }); - - - // check to see if we already have an *activated* endpoint for this filter - if (KSAEndpointForDeviceExists(parentInstanceId.c_str())) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"KSA endpoint for this filter already activated. Updating it.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") - ); - - std::shared_ptr existingActivatedEndpointDefinition { nullptr }; - - // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindActivatedMasterEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); - RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); - // add our new pins into the existing endpoint definition - existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), driverSuppliedName, existingActivatedEndpointDefinition)); + // Examples + // --------------------------------------------------------------------------- + // Parent values with serial: USB\VID_16C0&PID_05E4\ContinuuMini_SN024066 + // USB\VID_2573&PID_008A\no_serial_number (yes, this is the iSerialNumber in USB :/ + // 0x03 iSerialNumber "no serial number" + // USB\VID_12E6&PID_002C\251959d4f21fc283 + // + // Parent values without serial: USB\VID_F055&PID_0069\8&2858bbac&0&4 + // USB\VID_2662&PID_000D\8&24eb0394&0&4 + // USB\VID_05E3&PID_0610\7&2f028fc9&0&4 + // ROOT\MOTUBUS\0000 - RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); + std::wstring parentVal = systemDevicesParentValue.c_str(); - return S_OK; - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Endpoint for this device does not already exist. Proceed to creating new one.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") - ); - } + std::wstringstream ss(parentVal); + std::wstring usbSection{}; + std::getline(ss, usbSection, static_cast('\\')); - // if the endpointDefinition is null, that means we haven't found an existing - // activated endpoint definition we need to use, and so we proceed to check - // for an existing pending endpoint definition. If found, it's used. If not - // found, the function will create a new one for us to use, with all the - // endpoint-specific details (excluding pins) populated. - if (endpointDefinition == nullptr) + if (usbSection == L"USB") { - // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreatePendingMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); - RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); - - // add our new pins into the existing endpoint definition - endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - pinList.clear(); // just make sure we don't use this one, accidentally - } + // get the VID/PID section + std::wstring vidPidSection{}; -#ifdef _DEBUG - if (!driverSuppliedName.empty()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") - ); - } -#endif + std::getline(ss, vidPidSection, static_cast('\\')); - // At this point, we need to have *all* the pins for the endpoint, not just this filter - for (auto& pinDefinition : endpointDefinition->MidiPins) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != - internal::NormalizeDeviceInstanceIdWStringCopy(filterDevice.Id().c_str())) + if (!vidPidSection.empty()) { - // only process the pins for this filter interface. We don't want to - // change anything that has already been built. But we do need the - // context of all pins when getting the group index. - continue; - } - - // Figure out the group index for the pin. This needs the context of the - // entire device. Failure to get the group index is fatal - RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), - TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == - MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") - ); - - std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update - - // Build the name table entry for this individual pin - endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( - pinDefinition.GroupIndex, - pinDefinition.DataFlowFromUserPerspective, - customName, - driverSuppliedName, - pinDefinition.FilterName, - pinDefinition.PinName, - pinDefinition.PortIndexWithinThisFilterAndDirection - ); - } - - // we have an endpoint definition - m_endpointCreationThreadWakeup.SetEvent(); - - return S_OK; -} - -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceRemoved( - DeviceWatcher watcher, - DeviceInformationUpdate deviceInterfaceUpdate -) -{ - UNREFERENCED_PARAMETER(watcher); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "removed interface") - ); - - std::wstring removedFilterDeviceId{ internal::NormalizeDeviceInstanceIdWStringCopy(deviceInterfaceUpdate.Id().c_str()) }; + std::wstring serialSection{}; + std::getline(ss, serialSection, static_cast('\\')); - // find an active device with this filter + std::wstring vidPidString1{}; + std::wstring vidPidString2{}; - std::shared_ptr endpointDefinition{ nullptr }; + std::wstringstream ssVidPid(vidPidSection); + std::getline(ssVidPid, vidPidString1, static_cast('&')); + std::getline(ssVidPid, vidPidString2, static_cast('&')); - + wchar_t* end{ nullptr }; - for (auto& endpointListIterator : m_availableEndpointDefinitionsV2) - { - // check pins for this filter - for (auto& pin: endpointListIterator.second->MidiPins) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId) + // find the VID + if (vidPidString1.starts_with(L"VID_")) + { + endpointDefinition.VID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); + } + else if (vidPidString2.starts_with(L"VID_")) { - endpointDefinition = endpointListIterator.second; - break; + endpointDefinition.VID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); } - } - - } - if (endpointDefinition != nullptr) - { - bool done { false }; - - while (!done) - { - auto foundIt = std::find_if( - endpointDefinition->MidiPins.begin(), - endpointDefinition->MidiPins.end(), - [&removedFilterDeviceId](KsAggregateEndpointMidiPinDefinition& pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId; } - ); - - if (foundIt != endpointDefinition->MidiPins.end()) + // find the PID + if (vidPidString1.starts_with(L"PID_")) { - // erase the pin definition with this - endpointDefinition->MidiPins.erase(foundIt); + endpointDefinition.PID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); } - else + else if (vidPidString2.starts_with(L"PID_")) { - // we've removed all the pins for this interface - done = true; + endpointDefinition.PID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); } - } - - if (endpointDefinition->MidiPins.size() > 0) - { - // we've removed all the pins for this interface, but there are still - // pins left, so now it's time to update the endpoint - - - // TODO: Need to cache the name from the driver/registry so we don't have to do a lookup here. - - // update remaining pins in existing endpoint definition - RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, L"", endpointDefinition)); - RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); - } - else - { - auto lock = m_availableEndpointDefinitionsLock.lock(); - - // notify the device manager using the InstanceId for this midi device - RETURN_IF_FAILED(m_midiDeviceManager->RemoveEndpoint( - internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId).c_str())); - // remove the endpoint from the list - - m_availableEndpointDefinitionsV2.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId)); + // serial numbers with a & in them, are generated by our system + // it's possible a vendor may have a serial number with this in it, + // but in that case, we just ditch it. + if (serialSection.find_first_of('&') == serialSection.npos) + { + // Windows replaces spaces in the serial number with the underscore. + // yes, this will end up catching the few (if any) serials that do + // actually include an underscore. However, there are a bunch with spaces. + std::replace(serialSection.begin(), serialSection.end(), '_', ' '); + endpointDefinition.SerialNumber = serialSection; + } } } + else + { + // not a USB device, or otherwise uses a custom driver. We can't count + // on being able to parse the parent id. Example: MOTU has + // ROOT\MOTUBUS\0000 as the parent + } return S_OK; } -_Use_decl_annotations_ -HRESULT -CMidi2KSAggregateMidiEndpointManager::OnFilterDeviceInterfaceUpdated( - DeviceWatcher watcher, - DeviceInformationUpdate deviceInterfaceUpdate -) -{ - UNREFERENCED_PARAMETER(watcher); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "updated interface") - ); - - // Flow for interface UPDATED - // - Check for any name changes - // - If any relevant changes recalculate GTBs and Name table as above and update properties - // - - - - return S_OK; -} - - - _Use_decl_annotations_ @@ -2643,6 +1071,9 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded( } + + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved(DeviceWatcher watcher, DeviceInformationUpdate device) @@ -2751,49 +1182,24 @@ winrt::hstring CMidi2KSAggregateMidiEndpointManager::FindMatchingInstantiatedEnd { criteria.Normalize(); - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + for (auto const& def : m_availableEndpointDefinitions) { - for (auto const& def : m_availableEndpointDefinitionsV2) - { - WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; - available.DeviceInstanceId = def.second->EndpointDeviceInstanceId; - available.EndpointDeviceId = def.second->EndpointDeviceId; - available.UsbVendorId = def.second->VID; - available.UsbProductId = def.second->PID; - available.UsbSerialNumber = def.second->SerialNumber; - available.TransportSuppliedEndpointName = def.second->EndpointName; - available.DeviceManufacturerName = def.second->ManufacturerName; + available.DeviceInstanceId = def.second.EndpointDeviceInstanceId; + available.EndpointDeviceId = def.second.EndpointDeviceId; + available.UsbVendorId = def.second.VID; + available.UsbProductId = def.second.PID; + available.UsbSerialNumber = def.second.SerialNumber; + available.TransportSuppliedEndpointName = def.second.EndpointName; + available.DeviceManufacturerName = def.second.ManufacturerName; - if (available.Matches(criteria)) - { - return available.EndpointDeviceId; - } - } - } - else - { - for (auto const& def : m_availableEndpointDefinitions) + if (available.Matches(criteria)) { - WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; - - available.DeviceInstanceId = def.second.EndpointDeviceInstanceId; - available.EndpointDeviceId = def.second.EndpointDeviceId; - available.UsbVendorId = def.second.VID; - available.UsbProductId = def.second.PID; - available.UsbSerialNumber = def.second.SerialNumber; - available.TransportSuppliedEndpointName = def.second.EndpointName; - available.DeviceManufacturerName = def.second.ManufacturerName; - - if (available.Matches(criteria)) - { - return available.EndpointDeviceId; - } + return available.EndpointDeviceId; } } - - return L""; } @@ -2809,12 +1215,6 @@ CMidi2KSAggregateMidiEndpointManager::Shutdown() TraceLoggingPointer(this, "this") ); - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) - { - m_endpointCreationThread.request_stop(); - m_endpointCreationThreadWakeup.SetEvent(); - } - m_DeviceAdded.revoke(); m_DeviceRemoved.revoke(); m_DeviceUpdated.revoke(); diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h index 2e5bbdfdd..ee365894a 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager.h @@ -14,10 +14,6 @@ using namespace winrt::Windows::Devices::Enumeration; -#define DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS 250 -#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE 50 -#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE 2500 -#define KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE L"KsaInterfaceEnumTimeoutMS" struct KsAggregateEndpointMidiPinDefinition { @@ -59,49 +55,6 @@ struct KsAggregateEndpointDefinition }; - -// new structures because we need to be able to pull together -// virtual endpoints, which have greater than 16 ins and/or outs -// and so need the creation of multiple endpoints. Without the -// new 2603 approach, only 16 in and out ports are available -// per parent device (teVirtualMidi in this case). Also impacts -// loopBE30. - -struct KsAggregateEndpointDefinitionV2 -{ - std::wstring EndpointDeviceId{}; - - std::wstring EndpointName{}; - std::wstring EndpointDeviceInstanceId{}; - - std::vector MidiPins{ }; - - WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{}; - - int8_t CurrentHighestMidiSourceGroupIndex{ -1 }; - int8_t CurrentHighestMidiDestinationGroupIndex{ -1 }; -}; - - -struct KsAggregateParentDeviceDefinitionV2 -{ - std::wstring DeviceName{}; - std::wstring DeviceInstanceId{}; - std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming - - uint16_t VID{ 0 }; // USB-only - uint16_t PID{ 0 }; // USB-only - std::wstring SerialNumber{}; - - std::wstring ManufacturerName{}; - - std::vector Endpoints{ }; // most devices will have just one endpoint, but virtual can have > 1 -}; - - - - - class CMidi2KSAggregateMidiEndpointManager : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, @@ -117,8 +70,7 @@ class CMidi2KSAggregateMidiEndpointManager : private: STDMETHOD(CreateMidiUmpEndpoint)(_In_ KsAggregateEndpointDefinition& masterEndpointDefinition); - STDMETHOD(CreateMidiUmpEndpointV2)(_In_ std::shared_ptr masterEndpointDefinition); - + HRESULT OnDeviceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); HRESULT OnDeviceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); HRESULT OnDeviceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); @@ -130,65 +82,7 @@ class CMidi2KSAggregateMidiEndpointManager : wil::critical_section m_availableEndpointDefinitionsLock; std::map m_availableEndpointDefinitions; - - wil::critical_section m_pendingEndpointDefinitionsLock; - std::vector> m_pendingEndpointDefinitions; - - - - // new interface-based approachfor 2603 CFR update - HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); - HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); - HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); - - std::map> m_availableEndpointDefinitionsV2; - std::vector> m_pendingEndpointDefinitionsV2; - - - HRESULT FindActivatedEndpointDefinitionForFilterDevice( - _In_ std::wstring parentDeviceInstanceId, - _In_ std::shared_ptr&); - - HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( - _In_ DeviceInformation, - _In_ std::shared_ptr&); - HRESULT GetMidi1FilterPins( - _In_ DeviceInformation, - _In_ std::vector&); - - bool KSAEndpointForDeviceExists( - _In_ std::wstring deviceInstanceId); - - HRESULT IncrementAndGetNextGroupIndex( - _In_ std::shared_ptr definition, - _In_ MidiFlow dataFlowFromUserPerspective, - _In_ uint8_t& groupIndex); - - HRESULT UpdateNewPinDefinitions( - _In_ std::wstring filterDeviceid, - _In_ std::wstring driverSuppliedName, - _In_ std::shared_ptr endpointDefinition); - - HRESULT BuildPinsAndGroupTerminalBlocksPropertyData( - _In_ std::shared_ptr masterEndpointDefinition, - _In_ std::vector& pinMapPropertyData, - _In_ std::vector& groupTerminalBlocks); - - HRESULT UpdateNameTableWithCustomProperties( - _In_ std::shared_ptr masterEndpointDefinition, - _In_ std::shared_ptr customProperties); - - wil::unique_event_nothrow m_endpointCreationThreadWakeup; - std::jthread m_endpointCreationThread; - void EndpointCreationThreadWorker(_In_ std::stop_token token); - - HRESULT UpdateExistingMidiUmpEndpointWithFilterChanges( - _In_ std::shared_ptr masterEndpointDefinition); - - - - DeviceWatcher m_watcher{0}; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Removed_revoker m_DeviceRemoved; @@ -199,5 +93,5 @@ class CMidi2KSAggregateMidiEndpointManager : HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); - DWORD m_individualInterfaceEnumTimeoutMS { DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; + }; diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp new file mode 100644 index 000000000..3624278ca --- /dev/null +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -0,0 +1,1962 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License +// ============================================================================ +// This is part of the Windows MIDI Services App API and should be used +// in your Windows application via an official binary distribution. +// Further information: https://aka.ms/midi +// ============================================================================ + + + +#include "pch.h" + +#include // for the string stream in parsing of VID/PID/Serial from parent id +#include // for getline for string parsing of VID/PID/Serial from parent id + + +using namespace wil; +using namespace winrt::Windows::Devices::Enumeration; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace Microsoft::WRL; +using namespace Microsoft::WRL::Wrappers; + +#define INITIAL_ENUMERATION_TIMEOUT_MS 10000 + + + + +_Use_decl_annotations_ +HRESULT +KsAggregateParentDeviceDefinition2::AddPin( + std::shared_ptr pin) +{ + UNREFERENCED_PARAMETER(pin); + + // TODO + + // do we have any endpoints with this filter? + // if so, check to see if it has space for more pins + + // if no existing endpoint with this filter, then check to see if we have any + // endpoints at all we can add this to + + // if no endpoints at all, or if all other endpoints are filled, create + // a new endpoint + + + // Add this pin to the endpoint + + + + return S_OK; +} + + + + + + + + + + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::Initialize( + IMidiDeviceManager* midiDeviceManager, + IMidiEndpointProtocolManager* midiEndpointProtocolManager +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + RETURN_HR_IF(E_INVALIDARG, nullptr == midiDeviceManager); + RETURN_HR_IF(E_INVALIDARG, nullptr == midiEndpointProtocolManager); + + RETURN_IF_FAILED(midiDeviceManager->QueryInterface(__uuidof(IMidiDeviceManager), (void**)&m_midiDeviceManager)); + RETURN_IF_FAILED(midiEndpointProtocolManager->QueryInterface(__uuidof(IMidiEndpointProtocolManager), (void**)&m_midiProtocolManager)); + + // needed for internal consumption. Gary to replace this with feature enablement check + // defined in pch.h + DWORD individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; + if (SUCCEEDED(wil::reg::get_value_dword_nothrow(HKEY_LOCAL_MACHINE, MIDI_ROOT_REG_KEY, KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE, &individualInterfaceEnumTimeoutMS))) + { + individualInterfaceEnumTimeoutMS = max(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE); + individualInterfaceEnumTimeoutMS = min(individualInterfaceEnumTimeoutMS, KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE); + + m_individualInterfaceEnumTimeoutMS = individualInterfaceEnumTimeoutMS; + } + else + { + m_individualInterfaceEnumTimeoutMS = DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS; + } + + + // the ksa2603 fix enumerates device interfaces instead of parent devices + + winrt::hstring deviceInterfaceSelector( + L"System.Devices.InterfaceClassGuid:=\"{6994AD04-93EF-11D0-A3CC-00A0C9223196}\" AND " \ + L"System.Devices.InterfaceEnabled: = System.StructuredQueryType.Boolean#True"); + + auto additionalProps = winrt::single_threaded_vector(); + additionalProps.Append(L"System.Devices.Parent"); + + m_watcher = DeviceInformation::CreateWatcher(deviceInterfaceSelector); + + auto deviceAddedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded); + auto deviceRemovedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved); + auto deviceUpdatedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceUpdated); + + auto deviceStoppedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnDeviceWatcherStopped); + auto deviceEnumerationCompletedHandler = TypedEventHandler(this, &CMidi2KSAggregateMidiEndpointManager2::OnEnumerationCompleted); + + m_DeviceAdded = m_watcher.Added(winrt::auto_revoke, deviceAddedHandler); + m_DeviceRemoved = m_watcher.Removed(winrt::auto_revoke, deviceRemovedHandler); + m_DeviceUpdated = m_watcher.Updated(winrt::auto_revoke, deviceUpdatedHandler); + m_DeviceStopped = m_watcher.Stopped(winrt::auto_revoke, deviceStoppedHandler); + m_DeviceEnumerationCompleted = m_watcher.EnumerationCompleted(winrt::auto_revoke, deviceEnumerationCompletedHandler); + + // worker thread to handle endpoint creation, since we're enumerating interfaces now and need to aggregate them + m_endpointCreationThreadWakeup.create(wil::EventOptions::ManualReset); + std::jthread endpointCreationWorkerThread(std::bind_front(&CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker, this)); + m_endpointCreationThread = std::move(endpointCreationWorkerThread); + + m_watcher.Start(); + + // Wait for everything to be created so that they're available immediately after service start. + m_EnumerationCompleted.wait(INITIAL_ENUMERATION_TIMEOUT_MS); + + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + if (m_pendingEndpointDefinitions.size() > 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enumeration completed with endpoint definitions left pending.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingUInt32(static_cast(m_pendingEndpointDefinitions.size()), "count pending definitions") + ); + } + } + + return S_OK; +} + + +typedef struct { + BYTE GroupIndex; // index (0-15) of the group this pin maps to. It's also use when deciding on WinMM names + UINT32 PinId; // KS Pin number + MidiFlow PinDataFlow; // an input pin is MidiFlowIn, and from the user's perspective, a MIDI Output + std::wstring FilterId; // full filter id for this pin +} PinMapEntryStagingEntry; + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::UpdateNameTableWithCustomProperties( + std::shared_ptr masterEndpointDefinition, + std::shared_ptr customProperties) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + RETURN_HR_IF_NULL_EXPECTED(S_OK, customProperties); + RETURN_HR_IF(S_OK, customProperties->Midi1Destinations.size() == 0 && customProperties->Midi1Sources.size() == 0); + + for (auto const& pinEntry : masterEndpointDefinition->MidiPins) + { + if (pinEntry->PinDataFlow == MidiFlow::MidiFlowIn) + { + // message destination (output port), pin flow is In + if (auto customConfiguredName = customProperties->Midi1Destinations.find(pinEntry->GroupIndex); + customConfiguredName != customProperties->Midi1Destinations.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom name for a Midi 1 destination.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), + TraceLoggingUInt8(pinEntry->GroupIndex, "group index") + ); + + masterEndpointDefinition->EndpointNameTable.UpdateDestinationEntryCustomName(pinEntry->GroupIndex, customConfiguredName->second.Name); + } + } + else if (pinEntry->PinDataFlow == MidiFlow::MidiFlowOut) + { + // message source (input port), pin flow is Out + if (auto customConfiguredName = customProperties->Midi1Sources.find(pinEntry->GroupIndex); + customConfiguredName != customProperties->Midi1Sources.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom name for a Midi 1 source.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(customConfiguredName->second.Name.c_str(), "custom name"), + TraceLoggingUInt8(pinEntry->GroupIndex, "group index") + ); + + masterEndpointDefinition->EndpointNameTable.UpdateSourceEntryCustomName(pinEntry->GroupIndex, customConfiguredName->second.Name); + } + } + } + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyData( + std::shared_ptr masterEndpointDefinition, + std::vector& pinMapPropertyData, + std::vector& groupTerminalBlocks) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + + uint8_t currentBlockNumber{ 0 }; + std::vector pinMapEntries{ }; + + for (auto const& pin : masterEndpointDefinition->MidiPins) + { + RETURN_HR_IF(E_INVALIDARG, pin->FilterDeviceId.empty()); + + internal::GroupTerminalBlockInternal gtb; + + gtb.Number = ++currentBlockNumber; + gtb.GroupCount = 1; // always a single group for aggregate MIDI 1.0 devices + + PinMapEntryStagingEntry pinMapEntry{ }; + + pinMapEntry.PinId = pin->PinNumber; + pinMapEntry.FilterId = pin->FilterDeviceId; + pinMapEntry.PinDataFlow = pin->PinDataFlow; + + //MidiFlow flowFromUserPerspective; + + pinMapEntry.GroupIndex = pin->GroupIndex; + gtb.FirstGroupIndex = pin->GroupIndex; + + if (pin->PinDataFlow == MidiFlow::MidiFlowIn) // pin in, so user out : A MIDI Destination + { + gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_INPUT; // from the pin/gtb's perspective + + auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetDestinationEntry(gtb.FirstGroupIndex); + if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) + { + gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); + } + } + else if (pin->PinDataFlow == MidiFlow::MidiFlowOut) // pin out, so user in : A MIDI Source + { + gtb.Direction = MIDI_GROUP_TERMINAL_BLOCK_OUTPUT; // from the pin/gtb's perspective + auto nameTableEntry = masterEndpointDefinition->EndpointNameTable.GetSourceEntry(gtb.FirstGroupIndex); + if (nameTableEntry != nullptr && nameTableEntry->NewStyleName[0] != static_cast(0)) + { + gtb.Name = internal::TrimmedWStringCopy(nameTableEntry->NewStyleName); + } + } + else + { + RETURN_IF_FAILED(E_INVALIDARG); + } + + // name fallback + if (gtb.Name.empty()) + { + gtb.Name = masterEndpointDefinition->EndpointName; + + if (gtb.FirstGroupIndex > 0) + { + gtb.Name += L" " + std::wstring{ gtb.FirstGroupIndex }; + } + } + + // default values as defined in the MIDI 2.0 USB spec + gtb.Protocol = 0x01; // midi 1.0 + gtb.MaxInputBandwidth = 0x0001; // 31.25 kbps + gtb.MaxOutputBandwidth = 0x0001; // 31.25 kbps + + groupTerminalBlocks.push_back(gtb); + pinMapEntries.push_back(pinMapEntry); + } + + // Write Pin Map Property + // ===================================================== + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Building pin map property", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // build the pin map property value + KSAGGMIDI_PIN_MAP_PROPERTY_VALUE pinMap{ }; + + size_t totalStringSizesIncludingNulls{ 0 }; + for (auto const& entry : pinMapEntries) + { + totalStringSizesIncludingNulls += ((entry.FilterId.length() + 1) * sizeof(wchar_t)); + } + + size_t totalMemoryBytes{ + SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER + + SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING * pinMapEntries.size() + + totalStringSizesIncludingNulls }; + + pinMapPropertyData.resize(totalMemoryBytes); + auto currentPos = pinMapPropertyData.data(); + + // header + auto pinMapHeader = (PKSAGGMIDI_PIN_MAP_PROPERTY_VALUE)currentPos; + pinMapHeader->TotalByteCount = (UINT32)totalMemoryBytes; + currentPos += SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_VALUE_HEADER; + + for (auto const& entry : pinMapEntries) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Processing Pin Map entry", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingUInt32(entry.PinId, "Pin Id"), + TraceLoggingWideString(entry.FilterId.c_str(), "Filter Id") + ); + + PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY propEntry = (PKSAGGMIDI_PIN_MAP_PROPERTY_ENTRY)currentPos; + + propEntry->ByteCount = (UINT)(SIZET_KSAGGMIDI_PIN_MAP_PROPERTY_ENTRY_WITHOUT_STRING + ((entry.FilterId.length() + 1) * sizeof(wchar_t))); + propEntry->GroupIndex = entry.GroupIndex; + propEntry->PinDataFlow = entry.PinDataFlow; + propEntry->PinId = entry.PinId; + + if (!entry.FilterId.empty()) + { + wcscpy_s((wchar_t*)propEntry->FilterId, entry.FilterId.length() + 1, entry.FilterId.c_str()); + } + + currentPos += propEntry->ByteCount; + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"All pin map entries copied to property memory", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::CreateMidiUmpEndpoint( + std::shared_ptr masterEndpointDefinition, + std::shared_ptr parentDevice +) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + DEVPROP_BOOLEAN devPropTrue = DEVPROP_TRUE; + + // we require at least one valid pin + RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + + std::vector interfaceDevProperties; + + MIDIENDPOINTCOMMONPROPERTIES commonProperties{}; + commonProperties.TransportId = TRANSPORT_LAYER_GUID; + commonProperties.EndpointDeviceType = MidiEndpointDeviceType_Normal; + commonProperties.FriendlyName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.TransportCode = TRANSPORT_CODE; + commonProperties.EndpointName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.EndpointDescription = nullptr; + commonProperties.CustomEndpointName = nullptr; + commonProperties.CustomEndpointDescription = nullptr; + commonProperties.UniqueIdentifier = parentDevice->SerialNumber.empty() ? nullptr : parentDevice->SerialNumber.c_str(); + commonProperties.ManufacturerName = parentDevice->ManufacturerName.empty() ? nullptr : parentDevice->ManufacturerName.c_str(); + commonProperties.SupportedDataFormats = MidiDataFormats::MidiDataFormats_UMP; + commonProperties.NativeDataFormat = MidiDataFormats_ByteStream; + + UINT32 capabilities{ 0 }; + capabilities |= MidiEndpointCapabilities_SupportsMultiClient; + capabilities |= MidiEndpointCapabilities_GenerateIncomingTimestamps; + capabilities |= MidiEndpointCapabilities_SupportsMidi1Protocol; + commonProperties.Capabilities = (MidiEndpointCapabilities)capabilities; + + + std::vector pinMapPropertyData; + std::vector groupTerminalBlocks{ }; + std::vector nameTablePropertyData; + + RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( + masterEndpointDefinition, + pinMapPropertyData, + groupTerminalBlocks)); + + + interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); + + + // Write Group Terminal Block Property + // ===================================================== + + std::vector groupTerminalBlockPropertyData{}; + + if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockPropertyData)) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(groupTerminalBlockPropertyData.size()), + (PVOID)groupTerminalBlockPropertyData.data() }); + } + else + { + // write empty data + } + + + // Fold in custom properties, including MIDI 1 port names and naming approach + // =============================================================================== + + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.UsbVendorId = parentDevice->VID; + matchCriteria.UsbProductId = parentDevice->PID; + matchCriteria.UsbSerialNumber = parentDevice->SerialNumber; + matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + + auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); + + // rebuild the name table, using the custom properties if present + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + + std::wstring customName{ }; + std::wstring customDescription{ }; + if (customProperties != nullptr) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), + TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") + ); + + if (!customProperties->Name.empty()) + { + customName = customProperties->Name; + commonProperties.CustomEndpointName = customName.c_str(); + } + + if (!customProperties->Description.empty()) + { + customDescription = customProperties->Description; + commonProperties.CustomEndpointDescription = customDescription.c_str(); + } + + // this includes image, the Midi 1 naming approach, etc. + customProperties->WriteNonCommonProperties(interfaceDevProperties); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + ); + } + + // Write Name table property, folding in the custom names we discovered earlier + // =============================================================================================== + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + + // Write USB VID/PID Data + // ===================================================== + + if (parentDevice->VID > 0) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&parentDevice->VID }); + } + else + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbVID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_EMPTY, 0, nullptr }); + } + + if (parentDevice->PID > 0) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_UINT16, static_cast(sizeof(UINT16)), (PVOID)&parentDevice->PID }); + } + else + { + interfaceDevProperties.push_back({ { PKEY_MIDI_UsbPID, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_EMPTY, 0, nullptr }); + } + + + // Despite being a MIDI 1 device, we present as a UMP endpoint, so we need to set + // this property so the service can create the MIDI 1 ports without waiting for + // function blocks/discovery to complete or timeout (which it never will) + interfaceDevProperties.push_back({ { PKEY_MIDI_EndpointDiscoveryProcessComplete, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BOOLEAN, (ULONG)sizeof(devPropTrue), (PVOID)&devPropTrue }); + + SW_DEVICE_CREATE_INFO createInfo{ }; + + createInfo.cbSize = sizeof(createInfo); + createInfo.pszInstanceId = masterEndpointDefinition->EndpointDeviceInstanceId.c_str(); + createInfo.CapabilityFlags = SWDeviceCapabilitiesNone; + createInfo.pszDeviceDescription = masterEndpointDefinition->EndpointName.c_str(); + + // Call the device manager and finish the creation + + HRESULT swdCreationResult; + wil::unique_cotaskmem_string newDeviceInterfaceId; + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Activating endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // set to true if we only want to create UMP endpoints + bool umpOnly = false; + + LOG_IF_FAILED( + swdCreationResult = m_midiDeviceManager->ActivateEndpoint( + parentDevice->DeviceInstanceId.c_str(), + umpOnly, + MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint + &commonProperties, + (ULONG)interfaceDevProperties.size(), + (ULONG)0, + interfaceDevProperties.data(), + nullptr, + &createInfo, + &newDeviceInterfaceId) + ); + + if (SUCCEEDED(swdCreationResult)) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingWideString(newDeviceInterfaceId.get(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) + ); + + // return new device interface id + masterEndpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); + + auto lock = m_availableEndpointDefinitionsLock.lock(); + + // Add to internal endpoint manager + m_availableEndpointDefinitions.insert_or_assign( + internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), + masterEndpointDefinition); + + return swdCreationResult; + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD) + ); + + return swdCreationResult; + } +} + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::UpdateExistingMidiUmpEndpointWithFilterChanges( + std::shared_ptr masterEndpointDefinition, + std::shared_ptr parentDevice +) +{ + RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); + RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + ); + + // we require at least one valid pin + RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + + std::vector interfaceDevProperties{ }; + + std::vector pinMapPropertyData; + std::vector groupTerminalBlocks{ }; + std::vector nameTablePropertyData; + + // update the pin map to have all the existing pins + // plus the new pins. Update Group Terminal Blocks at the same time. + RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( + masterEndpointDefinition, + pinMapPropertyData, + groupTerminalBlocks)); + + + // Write Pin Map Property + // ===================================================== + interfaceDevProperties.push_back({ { DEVPKEY_KsAggMidiGroupPinMap, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, static_cast(pinMapPropertyData.size()), pinMapPropertyData.data() }); + + + // Write Group Terminal Block Property + // ===================================================== + + std::vector groupTerminalBlockData; + + if (internal::WriteGroupTerminalBlocksToPropertyDataPointer(groupTerminalBlocks, groupTerminalBlockData)) + { + interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_BINARY, (ULONG)groupTerminalBlockData.size(), (PVOID)groupTerminalBlockData.data() }); + } + else + { + // write empty data + } + + + // Fold in custom properties, including MIDI 1 port names and naming approach + // =============================================================================== + + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.UsbVendorId = parentDevice->VID; + matchCriteria.UsbProductId = parentDevice->PID; + matchCriteria.UsbSerialNumber = parentDevice->SerialNumber; + matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + + auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); + + std::wstring customName{ }; + std::wstring customDescription{ }; + if (customProperties != nullptr) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), + TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") + ); + + // this includes image, the Midi 1 naming approach, etc. + customProperties->WriteNonCommonProperties(interfaceDevProperties); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + ); + } + + // store the property data for the name table + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + + // Write Name table property, folding in the custom names we discovered earlier + // =============================================================================================== + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + + HRESULT updateResult{}; + + LOG_IF_FAILED(updateResult = m_midiDeviceManager->UpdateEndpointProperties( + masterEndpointDefinition->EndpointDeviceId.c_str(), + static_cast(interfaceDevProperties.size()), + interfaceDevProperties.data() + )); + + + if (SUCCEEDED(updateResult)) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint updated with new filter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) + ); + + auto lock = m_availableEndpointDefinitionsLock.lock(); + + // Add to internal endpoint manager + m_availableEndpointDefinitions.insert_or_assign( + internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), + masterEndpointDefinition); + + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Aggregate UMP endpoint update failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingHResult(updateResult, MIDI_TRACE_EVENT_HRESULT_FIELD) + ); + } + + return updateResult; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::GetPinName(HANDLE const hFilter, UINT const pinIndex, std::wstring& pinName) +{ + std::unique_ptr pinNameData; + ULONG pinNameDataSize{ 0 }; + + auto pinNameHR = PinPropertyAllocate( + hFilter, + pinIndex, + KSPROPSETID_Pin, + KSPROPERTY_PIN_NAME, + (PVOID*)&pinNameData, + &pinNameDataSize + ); + + if (SUCCEEDED(pinNameHR) || pinNameHR == HRESULT_FROM_WIN32(ERROR_SET_NOT_FOUND)) + { + // Check to see if the pin has an iJack name + if (pinNameDataSize > 0) + { + pinName = pinNameData.get(); + + return S_OK; + } + } + + return E_FAIL; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow) +{ + auto dataFlowHR = PinPropertySimple( + hFilter, + pinIndex, + KSPROPSETID_Pin, + KSPROPERTY_PIN_DATAFLOW, + &dataFlow, + sizeof(KSPIN_DATAFLOW) + ); + + if (SUCCEEDED(dataFlowHR)) + { + return S_OK; + } + + return E_FAIL; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::GetKSDriverSuppliedName(HANDLE hInstantiatedFilter, std::wstring& name) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // get the name GUID + + KSCOMPONENTID componentId{}; + KSPROPERTY prop{}; + ULONG countBytesReturned{}; + + prop.Id = KSPROPERTY_GENERAL_COMPONENTID; + prop.Set = KSPROPSETID_General; + prop.Flags = KSPROPERTY_TYPE_GET; + + auto hrComponent = SyncIoctl( + hInstantiatedFilter, + IOCTL_KS_PROPERTY, + &prop, + sizeof(KSPROPERTY), + &componentId, + sizeof(KSCOMPONENTID), + &countBytesReturned + ); + + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + // changed to not log the failure here. Failures are expected for many devices, and it's adding noise to error logs + if (FAILED(hrComponent)) + { + return hrComponent; + } + } + else + { + RETURN_IF_FAILED(hrComponent); + } + + componentId.Name; // this is the GUID which points to the registry location with the driver-supplied name + + if (componentId.Name != GUID_NULL) + { + // we have the GUID where this name is stored, so get the driver-supplied name from the registry + + WCHAR nameFromRegistry[MAX_PATH]{ 0 }; // this should only be MAXPNAMELEN, but if someone tampered with it, could be larger, hence MAX_PATH + + std::wstring regKey = L"SYSTEM\\CurrentControlSet\\Control\\MediaCategories\\" + internal::GuidToString(componentId.Name); + + if (SUCCEEDED(wil::reg::get_value_string_nothrow(HKEY_LOCAL_MACHINE, regKey.c_str(), L"Name", nameFromRegistry))) + { + name = nameFromRegistry; + } + + return S_OK; + } + + return E_NOTFOUND; +} + + + +#define KS_CATEGORY_AUDIO_GUID L"{6994AD04-93EF-11D0-A3CC-00A0C9223196}" + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial( + winrt::hstring systemDevicesParentValue, + std::shared_ptr& parentDevice) +{ + + if (systemDevicesParentValue.empty()) + { + RETURN_IF_FAILED(E_INVALIDARG); + } + + + // Examples + // --------------------------------------------------------------------------- + // Parent values with serial: USB\VID_16C0&PID_05E4\ContinuuMini_SN024066 + // USB\VID_2573&PID_008A\no_serial_number (yes, this is the iSerialNumber in USB :/ + // 0x03 iSerialNumber "no serial number" + // USB\VID_12E6&PID_002C\251959d4f21fc283 + // + // Parent values without serial: USB\VID_F055&PID_0069\8&2858bbac&0&4 + // USB\VID_2662&PID_000D\8&24eb0394&0&4 + // USB\VID_05E3&PID_0610\7&2f028fc9&0&4 + // ROOT\MOTUBUS\0000 + + std::wstring parentVal = systemDevicesParentValue.c_str(); + + std::wstringstream ss(parentVal); + std::wstring usbSection{}; + + std::getline(ss, usbSection, static_cast('\\')); + + if (usbSection == L"USB") + { + // get the VID/PID section + + std::wstring vidPidSection{}; + + std::getline(ss, vidPidSection, static_cast('\\')); + + if (!vidPidSection.empty()) + { + std::wstring serialSection{}; + std::getline(ss, serialSection, static_cast('\\')); + + std::wstring vidPidString1{}; + std::wstring vidPidString2{}; + + std::wstringstream ssVidPid(vidPidSection); + std::getline(ssVidPid, vidPidString1, static_cast('&')); + std::getline(ssVidPid, vidPidString2, static_cast('&')); + + wchar_t* end{ nullptr }; + + // find the VID + if (vidPidString1.starts_with(L"VID_")) + { + parentDevice->VID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); + } + else if (vidPidString2.starts_with(L"VID_")) + { + parentDevice->VID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); + } + + // find the PID + if (vidPidString1.starts_with(L"PID_")) + { + parentDevice->PID = static_cast(wcstol(vidPidString1.substr(4).c_str(), &end, 16)); + } + else if (vidPidString2.starts_with(L"PID_")) + { + parentDevice->PID = static_cast(wcstol(vidPidString2.substr(4).c_str(), &end, 16)); + } + + // serial numbers with a & in them, are generated by our system + // it's possible a vendor may have a serial number with this in it, + // but in that case, we just ditch it. + if (serialSection.find_first_of('&') == serialSection.npos) + { + // Windows replaces spaces in the serial number with the underscore. + // yes, this will end up catching the few (if any) serials that do + // actually include an underscore. However, there are a bunch with spaces. + std::replace(serialSection.begin(), serialSection.end(), '_', ' '); + parentDevice->SerialNumber = serialSection; + } + } + } + else + { + // not a USB device, or otherwise uses a custom driver. We can't count + // on being able to parse the parent id. Example: MOTU has + // ROOT\MOTUBUS\0000 as the parent + } + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterDevice( + std::wstring parentDeviceInstanceId, + std::shared_ptr& endpointDefinition +) +{ + for (auto const& entry : m_availableEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->DeviceInstanceId) == + internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) + { + endpointDefinition = entry.second; + + return S_OK; + } + } + + return E_NOTFOUND; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice( + DeviceInformation filterDevice, + std::shared_ptr& endpointDefinition +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // we require that the System.Devices.DeviceInstanceId property was requested for the passed-in filter device + auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); + + auto additionalProperties = winrt::single_threaded_vector(); + additionalProperties.Append(L"System.Devices.DeviceManufacturer"); + additionalProperties.Append(L"System.Devices.Manufacturer"); + additionalProperties.Append(L"System.Devices.Parent"); + + auto parentDevice = DeviceInformation::CreateFromIdAsync(deviceInstanceId, + additionalProperties, winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get(); + + // See if we already have a pending master endpoint definition for this parent device + + auto lock = m_pendingEndpointDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing + + auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); + auto it = std::find_if( + m_pendingEndpointDefinitions.begin(), + m_pendingEndpointDefinitions.end(), + [&parentInstanceIdToFind](const std::shared_ptr def){return internal::NormalizeDeviceInstanceIdWStringCopy(def->ParentDeviceInstanceId) == parentInstanceIdToFind; }); + + if (it != m_pendingEndpointDefinitions.end()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found existing aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentInstanceIdToFind.c_str(), "parent") + ); + + endpointDefinition = *it; + return S_OK; + } + + // We don't have one, create one and add, and get all the parent device information for it + // we still have the map locked, so keep this code fast + auto newEndpointDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_OUTOFMEMORY, newEndpointDefinition); + + //newEndpointDefinition->ParentDeviceName = parentDevice.Name(); + //newEndpointDefinition->EndpointName = parentDevice.Name(); + //newEndpointDefinition->ParentDeviceInstanceId = parentDevice.Id(); + + //LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newEndpointDefinition->ParentDeviceInstanceId.c_str(), *newEndpointDefinition)); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Creating new aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDevice.Id().c_str(), "parent") + ); + + // only some vendor drivers provide an actual manufacturer + // and all the in-box drivers just provide the Generic USB Audio string + // TODO: Is "Generic USB Audio" a string that is localized? If so, this + // code will not have the intended effect outside of en-US + //auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); + //if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") + //{ + // newEndpointDefinition->ManufacturerName = manufacturer; + //} + + //// default hash is the device id. + //std::hash hasher; + //std::wstring hash; + //hash = std::to_wstring(hasher(newEndpointDefinition->ParentDeviceInstanceId)); + + //newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + m_pendingEndpointDefinitions.push_back(newEndpointDefinition); + endpointDefinition = newEndpointDefinition; + + + return S_OK; +} + +//_Use_decl_annotations_ +//HRESULT +//CMidi2KSAggregateMidiEndpointManager::IncrementAndGetNextGroupIndex( +// std::shared_ptr definition, +// MidiFlow dataFlowFromUserPerspective, +// uint8_t& groupIndex) +//{ +// if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn) +// { +// definition->CurrentHighestMidiSourceGroupIndex++; +// groupIndex = definition->CurrentHighestMidiSourceGroupIndex; +// } +// else +// { +// definition->CurrentHighestMidiDestinationGroupIndex++; +// groupIndex = definition->CurrentHighestMidiDestinationGroupIndex; +// } +// +// return S_OK; +//} + +#define MAX_THREAD_WORKER_WAIT_TIME_MS 20000 + +_Use_decl_annotations_ +void CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker( + std::stop_token token) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + while (!token.stop_requested()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Waiting to be woken up.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + // wait to be woken up + m_endpointCreationThreadWakeup.wait(MAX_THREAD_WORKER_WAIT_TIME_MS); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: I'm awake now.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + if (!token.stop_requested()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Short nap before checking.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingUInt32(m_individualInterfaceEnumTimeoutMS, "nap period (ms)") + ); + + // we sleep for this timeout before we check to see if the thread is signaled. + // this gives time for an additional interface notification to cause the event + // to be reset + Sleep(m_individualInterfaceEnumTimeoutMS); + + // if we're still signaled, that means no other pnp notifications came in during the nap, or if they + // did, they completed within that nap period + if (m_endpointCreationThreadWakeup.is_signaled() && m_pendingEndpointDefinitions.size() > 0) + { + m_endpointCreationThreadWakeup.ResetEvent(); // it's a manual reset event + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Thread was signaled and pending definition count > 0. Proceed to processing queue.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + + // lock the definitions so we can process them + auto lock = m_pendingEndpointDefinitionsLock.lock(); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Processing pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + while (m_pendingEndpointDefinitions.size() > 0) + { + // effectively a queue, but because we have to iterate and search + // this in other functions, a vector is more appropriate + auto ep = m_pendingEndpointDefinitions[0]; + m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin()); + + // create the endpoint + LOG_IF_FAILED(CreateMidiUmpEndpoint(ep)); + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + +#ifdef _DEBUG + else + { + if (m_pendingEndpointDefinitions.size() == 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but no work to do. Pending count == 0.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but thread is no longer signaled", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + } + } +#endif + } + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Exit", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + +} + +_Use_decl_annotations_ +bool CMidi2KSAggregateMidiEndpointManager2::ActiveKSAEndpointForDeviceExists( + _In_ std::wstring parentDeviceInstanceId) +{ + for (auto const& entry : m_availableEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->DeviceInstanceId) == + internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) + { + return true; + } + } + + return false; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( + DeviceInformation filterDevice, + std::vector& pinListToAddTo +) +{ + // Wrapper opens the handle internally. + KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); + RETURN_IF_FAILED(deviceHandleWrapper.Open()); + + // ============================================================================================= + // Go through all the enumerated pins, looking for a MIDI 1.0 pin + + // enumerate all the pins for this filter + ULONG cPins{ 0 }; + + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return PinPropertySimple(h, 0, KSPROPSETID_Pin, KSPROPERTY_PIN_CTYPES, &cPins, sizeof(cPins)); + })); + + ULONG midiInputPinIndexForThisFilter{ 0 }; + ULONG midiOutputPinIndexForThisFilter{ 0 }; + + // process the pins for this filter. Not all will be MIDI pins + for (UINT pinIndex = 0; pinIndex < cPins; pinIndex++) + { + // Check the communication capabilities of the pin so we can fail fast + KSPIN_COMMUNICATION communication = (KSPIN_COMMUNICATION)0; + + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return PinPropertySimple(h, pinIndex, KSPROPSETID_Pin, KSPROPERTY_PIN_COMMUNICATION, &communication, sizeof(KSPIN_COMMUNICATION)); + })); + + // The external connector pin representing the physical connection + // has KSPIN_COMMUNICATION_NONE. We can only create on software IO pins which + // have a valid communication. Skip connector pins. + if (communication == KSPIN_COMMUNICATION_NONE) + { + continue; + } + + // Duplicate the handle to safely pass it to another component or store it. + wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle()); + RETURN_IF_NULL_ALLOC(handleDupe); + + // we try to open UMP only so we understand the device + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Checking for UMP pin. This will fallback error fail for non-UMP devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + + KsHandleWrapper m_PinHandleWrapperUmp(filterDevice.Id().c_str(), pinIndex, MidiTransport_CyclicUMP, handleDupe.get()); + if (SUCCEEDED(m_PinHandleWrapperUmp.Open())) + { + // this is a UMP pin. The KS transport will handle it, so we skip it here. + // In the future, we may want to bail on the first UMP pin we find. + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Found UMP/MIDI2 pin. Skipping for this transport.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + + continue; + } + + + // try to open as a MIDI 1 bytestream pin + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Checking for MIDI 1 pin. This will fallback error fail for non-MIDI devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + + KsHandleWrapper pinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); + if (SUCCEEDED(pinHandleWrapperMidi1.Open())) + { + // this is a MIDI 1.0 byte format pin, so let's process it + KsAggregateEndpointMidiPinDefinition pinDefinition{ }; + + pinDefinition.PinNumber = pinIndex; + pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; + pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; + + // find the name of the pin (supplied by iJack, and if not available, generated by the stack) + std::wstring pinName{ }; + HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinName(h, pinIndex, pinName); + }); + + if (SUCCEEDED(pinNameHr)) + { + pinDefinition.PinName = pinName; + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Pin has name", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(pinDefinition.PinName.c_str(), "pin name") + ); + } + + // get the data flow so we know if this is a MIDI Input (Source) or a MIDI Output (Destination) + KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; + RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetPinDataFlow(h, pinIndex, dataFlow); + })); + + if (dataFlow == KSPIN_DATAFLOW_IN) + { + // MIDI Out (input to device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); + + midiOutputPinIndexForThisFilter++; + } + else if (dataFlow == KSPIN_DATAFLOW_OUT) + { + // MIDI In (output from device) + pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; + pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite + pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); + + midiInputPinIndexForThisFilter++; + } + + pinListToAddTo.push_back(pinDefinition); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"MIDI 1.0 pin added", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + } + } + + + return S_OK; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( + std::wstring filterDeviceid, + std::wstring driverSuppliedName, + std::shared_ptr endpointDefinition) +{ + // At this point, we need to have *all* the pins for the endpoint, not just this filter + for (auto& pinDefinition : endpointDefinition->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition->FilterDeviceId) != + internal::NormalizeDeviceInstanceIdWStringCopy(filterDeviceid)) + { + // only process the pins for this filter interface. We don't want to + // change anything that has already been built. But we do need the + // context of all pins when getting the group index. + continue; + } + + // Figure out the group index for the pin. This needs the context of the + // entire device. Failure to get the group index is fatal +// RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceid.c_str(), "filter device id"), + TraceLoggingUInt8(pinDefinition->GroupIndex, "group index"), + TraceLoggingWideString(pinDefinition->DataFlowFromUserPerspective == + MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") + ); + + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update + + // Build the name table entry for this individual pin + endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition->GroupIndex, + pinDefinition->DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition->FilterName, + pinDefinition->PinName, + pinDefinition->PortIndexWithinThisFilterAndDirection + ); + } + + return S_OK; +} + + + + +//HRESULT +//PopulatePinKSDataFormats(HANDLE filterHandle/*, Some_vector_of_pin_format_structs*/) +//{ +// //Try this, it should be a fairly easy thing to add to your change. +// // retrieve the : +// //KSPROPSETID_Pin, +// // KSPROPERTY_PIN_DATARANGES, +// +// // limit to pins with(pKsDataFormat->MajorFormat == KSDATAFORMAT_TYPE_MUSIC) +// // +// // Retrieval is going to follow the same ksmultipleitemp pattern as KSPROPERTY_MIDI2_GROUP_TERMINAL_BLOCKS +//} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( + DeviceWatcher /* watcher */, + DeviceInformation filterDevice +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") + ); + + std::wstring transportCode(TRANSPORT_CODE); + + // Wrapper opens the handle internally. + KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); + RETURN_IF_FAILED(deviceHandleWrapper.Open()); + + std::shared_ptr endpointDefinition{ nullptr }; + + // get all the MIDI 1 pins. These are only partially processed because some things + // like group index and naming require the full context of all filters/pins on the + // parent device. + std::vector pinList{ }; + RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList)); + + if (pinList.size() == 0) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No MIDI 1.0 pins for this filter device", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + + return S_OK; + } + + auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); + RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); + + // Driver-supplied name. This is needed for WinMM backwards compatibility + std::wstring driverSuppliedName{}; + + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail, + // adding unnecessary noise to the error logs + deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); + + + // check to see if we already have an *activated* endpoint for this filter + if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str())) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"KSA endpoint for this filter already activated. Updating it.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") + ); + + std::shared_ptr existingActivatedEndpointDefinition { nullptr }; + + // first MIDI 1 pin we're processing for this interface + RETURN_IF_FAILED(FindActivatedMasterEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); + + // add our new pins into the existing endpoint definition + existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), driverSuppliedName, existingActivatedEndpointDefinition)); + + RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); + + return S_OK; + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Endpoint for this device does not already exist. Proceed to creating new one.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") + ); + } + + + // if the endpointDefinition is null, that means we haven't found an existing + // activated endpoint definition we need to use, and so we proceed to check + // for an existing pending endpoint definition. If found, it's used. If not + // found, the function will create a new one for us to use, with all the + // endpoint-specific details (excluding pins) populated. + if (endpointDefinition == nullptr) + { + // first MIDI 1 pin we're processing for this interface + RETURN_IF_FAILED(FindOrCreatePendingMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); + + // add our new pins into the existing endpoint definition + endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + pinList.clear(); // just make sure we don't use this one, accidentally + } + + +#ifdef _DEBUG + if (!driverSuppliedName.empty()) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") + ); + } +#endif + + // At this point, we need to have *all* the pins for the endpoint, not just this filter + for (auto& pinDefinition : endpointDefinition->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != + internal::NormalizeDeviceInstanceIdWStringCopy(filterDevice.Id().c_str())) + { + // only process the pins for this filter interface. We don't want to + // change anything that has already been built. But we do need the + // context of all pins when getting the group index. + continue; + } + + // Figure out the group index for the pin. This needs the context of the + // entire device. Failure to get the group index is fatal + RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), + TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == + MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") + ); + + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update + + // Build the name table entry for this individual pin + endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( + pinDefinition.GroupIndex, + pinDefinition.DataFlowFromUserPerspective, + customName, + driverSuppliedName, + pinDefinition.FilterName, + pinDefinition.PinName, + pinDefinition.PortIndexWithinThisFilterAndDirection + ); + } + + // we have an endpoint definition + m_endpointCreationThreadWakeup.SetEvent(); + + return S_OK; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( + DeviceWatcher watcher, + DeviceInformationUpdate deviceInterfaceUpdate +) +{ + UNREFERENCED_PARAMETER(watcher); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "removed interface") + ); + + std::wstring removedFilterDeviceId{ internal::NormalizeDeviceInstanceIdWStringCopy(deviceInterfaceUpdate.Id().c_str()) }; + + // find an active device with this filter + + std::shared_ptr endpointDefinition{ nullptr }; + + + + for (auto& endpointListIterator : m_availableEndpointDefinitions) + { + // check pins for this filter + for (auto& pin: endpointListIterator.second->MidiPins) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId) + { + endpointDefinition = endpointListIterator.second; + break; + } + } + + } + + if (endpointDefinition != nullptr) + { + bool done { false }; + + while (!done) + { + auto foundIt = std::find_if( + endpointDefinition->MidiPins.begin(), + endpointDefinition->MidiPins.end(), + [&removedFilterDeviceId](KsAggregateEndpointMidiPinDefinition& pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId; } + ); + + if (foundIt != endpointDefinition->MidiPins.end()) + { + // erase the pin definition with this + endpointDefinition->MidiPins.erase(foundIt); + } + else + { + // we've removed all the pins for this interface + done = true; + } + } + + if (endpointDefinition->MidiPins.size() > 0) + { + // we've removed all the pins for this interface, but there are still + // pins left, so now it's time to update the endpoint + + + // TODO: Need to cache the name from the driver/registry so we don't have to do a lookup here. + + // update remaining pins in existing endpoint definition + RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, L"", endpointDefinition)); + RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); + } + else + { + auto lock = m_availableEndpointDefinitionsLock.lock(); + + // notify the device manager using the InstanceId for this midi device + RETURN_IF_FAILED(m_midiDeviceManager->RemoveEndpoint( + internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId).c_str())); + + // remove the endpoint from the list + + m_availableEndpointDefinitions.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId)); + } + } + + return S_OK; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceUpdated( + DeviceWatcher watcher, + DeviceInformationUpdate deviceInterfaceUpdate +) +{ + UNREFERENCED_PARAMETER(watcher); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(deviceInterfaceUpdate.Id().c_str(), "updated interface") + ); + + // Flow for interface UPDATED + // - Check for any name changes + // - If any relevant changes recalculate GTBs and Name table as above and update properties + // + + + + return S_OK; +} + + + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::OnDeviceWatcherStopped(DeviceWatcher watcher, winrt::Windows::Foundation::IInspectable) +{ + UNREFERENCED_PARAMETER(watcher); + + m_EnumerationCompleted.SetEvent(); + return S_OK; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::OnEnumerationCompleted(DeviceWatcher watcher, winrt::Windows::Foundation::IInspectable) +{ + UNREFERENCED_PARAMETER(watcher); + + m_EnumerationCompleted.SetEvent(); + return S_OK; +} + + +_Use_decl_annotations_ +winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEndpoint(WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria) +{ + criteria.Normalize(); + + for (auto const& def : m_availableEndpointDefinitions) + { + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; + + available.DeviceInstanceId = def.second->DeviceInstanceId; + available.EndpointDeviceId = def.second->EndpointDeviceId; + available.UsbVendorId = def.second->VID; + available.UsbProductId = def.second->PID; + available.UsbSerialNumber = def.second->SerialNumber; + available.TransportSuppliedEndpointName = def.second->EndpointName; + available.DeviceManufacturerName = def.second->ManufacturerName; + + if (available.Matches(criteria)) + { + return available.EndpointDeviceId; + } + } + + return L""; +} + +HRESULT +CMidi2KSAggregateMidiEndpointManager2::Shutdown() +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this") + ); + + m_endpointCreationThread.request_stop(); + m_endpointCreationThreadWakeup.SetEvent(); + + m_DeviceAdded.revoke(); + m_DeviceRemoved.revoke(); + m_DeviceUpdated.revoke(); + m_DeviceStopped.revoke(); + + m_DeviceEnumerationCompleted.revoke(); + + m_watcher.Stop(); + + uint8_t tries{ 0 }; + while (m_watcher.Status() != DeviceWatcherStatus::Stopped && tries < 50) + { + Sleep(100); + tries++; + } + + TransportState::Current().Shutdown(); + + m_midiDeviceManager.reset(); + m_midiProtocolManager.reset(); + + return S_OK; +} \ No newline at end of file diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h new file mode 100644 index 000000000..35cbc3eae --- /dev/null +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License +// ============================================================================ +// This is part of the Windows MIDI Services App API and should be used +// in your Windows application via an official binary distribution. +// Further information: https://aka.ms/midi +// ============================================================================ + +#pragma once + +#include "MidiEndpointCustomProperties.h" +#include "MidiEndpointMatchCriteria.h" +#include "MidiEndpointNameTable.h" + +using namespace winrt::Windows::Devices::Enumeration; + +#define DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS 250 +#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MINIMUM_VALUE 50 +#define KSA_INTERFACE_ENUM_TIMEOUT_MS_MAXIMUM_VALUE 2500 +#define KSA_INTERFACE_ENUM_TIMEOUT_REG_VALUE L"KsaInterfaceEnumTimeoutMS" + +struct KsAggregateEndpointMidiPinDefinition2 +{ + //std::wstring KSDriverSuppliedName; + std::wstring FilterDeviceId; // this is also the value needed by WinMM for DRV_QUERYDEVICEINTERFACE + std::wstring FilterName; + + ULONG PinNumber; + std::wstring PinName; + + MidiFlow PinDataFlow; + MidiFlow DataFlowFromUserPerspective; + + uint8_t GroupIndex{ 0 }; + uint8_t PortIndexWithinThisFilterAndDirection{ 0 }; // not always the same as the group index. Example: MOTU Express 128 with separate filter for each in/out pair + + // internal::Midi1PortNaming::Midi1PortNameEntry PortNames; +}; + +struct KsAggregateEndpointDefinition2 +{ + std::wstring EndpointDeviceId{}; + + std::wstring EndpointName{}; + std::wstring EndpointDeviceInstanceId{}; + + std::vector> MidiPins{ }; + + WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{ }; +}; + + +class KsAggregateParentDeviceDefinition2 +{ +public: + std::wstring DeviceName{}; + std::wstring DeviceInstanceId{}; + std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming + + uint16_t VID{ 0 }; // USB-only + uint16_t PID{ 0 }; // USB-only + std::wstring SerialNumber{}; + + std::wstring ManufacturerName{}; + + std::vector> Endpoints{ }; // most devices will have just one endpoint, but virtual can have > 1 + + + // This will add a pin, and create new endpoints as needed, assign the group index, etc. + // it will also update the name table given the info we have + HRESULT AddPin(_In_ std::shared_ptr pin); + + //HRESULT RemoveFilter(_In_ std::wstring filterId); + + + +private: + +}; + + +class CMidi2KSAggregateMidiEndpointManager2 : + public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IMidiEndpointManager> +{ +public: + + STDMETHOD(Initialize(_In_ IMidiDeviceManager*, _In_ IMidiEndpointProtocolManager*)); + STDMETHOD(Shutdown)(); + + // returns the endpointDeviceInterfaceId for a matching endpoint found in m_availableEndpointDefinitions + winrt::hstring FindMatchingInstantiatedEndpoint(_In_ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria); + +private: + wil::com_ptr_nothrow m_midiDeviceManager; + wil::com_ptr_nothrow m_midiProtocolManager; + + wil::critical_section m_availableEndpointDefinitionsLock; + wil::critical_section m_pendingEndpointDefinitionsLock; + + HRESULT ParseParentIdIntoVidPidSerial( + _In_ winrt::hstring systemDevicesParentValue, + _In_ std::shared_ptr& parentDevice); + + HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName); + HRESULT GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow); + + STDMETHOD(CreateMidiUmpEndpoint)( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr parentDevice); + + HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); + HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); + HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); + + HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); + HRESULT OnDeviceWatcherStopped(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); + + // key is parent device instance id + std::map> m_availableEndpointDefinitions; + std::vector> m_pendingEndpointDefinitions; + + + HRESULT FindActivatedEndpointDefinitionForFilterDevice( + _In_ std::wstring parentDeviceInstanceId, + _In_ std::shared_ptr&); + + HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( + _In_ DeviceInformation, + _In_ std::shared_ptr&); + + HRESULT GetMidi1FilterPins( + _In_ DeviceInformation, + _In_ std::vector&); + + bool ActiveKSAEndpointForDeviceExists( + _In_ std::wstring deviceInstanceId); + + //HRESULT IncrementAndGetNextGroupIndex( + // _In_ std::shared_ptr definition, + // _In_ MidiFlow dataFlowFromUserPerspective, + // _In_ uint8_t& groupIndex); + + HRESULT UpdateNewPinDefinitions( + _In_ std::wstring filterDeviceid, + _In_ std::wstring driverSuppliedName, + _In_ std::shared_ptr endpointDefinition); + + HRESULT BuildPinsAndGroupTerminalBlocksPropertyData( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::vector& pinMapPropertyData, + _In_ std::vector& groupTerminalBlocks); + + HRESULT UpdateNameTableWithCustomProperties( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr customProperties); + + wil::unique_event_nothrow m_endpointCreationThreadWakeup; + std::jthread m_endpointCreationThread; + void EndpointCreationThreadWorker(_In_ std::stop_token token); + + HRESULT UpdateExistingMidiUmpEndpointWithFilterChanges( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr parentDevice); + + + DeviceWatcher m_watcher{0}; + winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; + winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Removed_revoker m_DeviceRemoved; + winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Updated_revoker m_DeviceUpdated; + winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Stopped_revoker m_DeviceStopped; + winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::EnumerationCompleted_revoker m_DeviceEnumerationCompleted; + wil::unique_event m_EnumerationCompleted{wil::EventOptions::None}; + + HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); + + DWORD m_individualInterfaceEnumTimeoutMS { DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; +}; diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp index f3339531d..e8cd71c79 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp @@ -44,14 +44,24 @@ CMidi2KSAggregateTransport::Activate( TraceLoggingWideString(L"IMidiEndpointManager", MIDI_TRACE_EVENT_INTERFACE_FIELD) ); - // check to see if this is the first time we're creating the endpoint manager. If so, create it. if (TransportState::Current().GetEndpointManager() == nullptr) { TransportState::Current().ConstructEndpointManager(); } - RETURN_IF_FAILED(TransportState::Current().GetEndpointManager()->QueryInterface(iid, activatedInterface)); + + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + RETURN_IF_FAILED(TransportState::Current().GetEndpointManager2()->QueryInterface(iid, activatedInterface)); + } + else + { + RETURN_IF_FAILED(TransportState::Current().GetEndpointManager()->QueryInterface(iid, activatedInterface)); + } + + + } else if (__uuidof(IMidiTransportConfigurationManager) == iid) { diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj index 822f30a55..7cbd5b341 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj @@ -359,6 +359,7 @@ + @@ -369,7 +370,7 @@ - + @@ -380,6 +381,7 @@ + @@ -392,7 +394,7 @@ - + diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters index 69c3f79d7..cabf18a58 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters @@ -39,7 +39,7 @@ Source Files - + Source Files @@ -48,6 +48,9 @@ Source Files + + Source Files + @@ -94,7 +97,7 @@ Header Files - + Header Files @@ -103,6 +106,9 @@ Header Files + + Header Files + diff --git a/src/api/Transport/KSAggregateTransport/TransportState.cpp b/src/api/Transport/KSAggregateTransport/TransportState.cpp index da0e8411c..2e723832a 100644 --- a/src/api/Transport/KSAggregateTransport/TransportState.cpp +++ b/src/api/Transport/KSAggregateTransport/TransportState.cpp @@ -25,12 +25,22 @@ TransportState& TransportState::Current() HRESULT TransportState::ConstructEndpointManager() { - RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager)); + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager2)); + } + else + { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&m_endpointManager)); + } + return S_OK; } + + HRESULT TransportState::ConstructConfigurationManager() { diff --git a/src/api/Transport/KSAggregateTransport/TransportState.h b/src/api/Transport/KSAggregateTransport/TransportState.h index 177f10f33..57726be7b 100644 --- a/src/api/Transport/KSAggregateTransport/TransportState.h +++ b/src/api/Transport/KSAggregateTransport/TransportState.h @@ -23,7 +23,27 @@ class TransportState wil::com_ptr GetEndpointManager() { - return m_endpointManager; + if (!Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + return m_endpointManager; + } + else + { + return nullptr; + } + } + + // for Feature_Servicing_MIDI2VirtualPortDriversFix + wil::com_ptr GetEndpointManager2() + { + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + return m_endpointManager2; + } + else + { + return nullptr; + } } wil::com_ptr GetConfigurationManager() @@ -51,6 +71,10 @@ class TransportState wil::com_ptr m_endpointManager; + + // for Feature_Servicing_MIDI2VirtualPortDriversFix + wil::com_ptr m_endpointManager2 { nullptr }; + wil::com_ptr m_configurationManager; }; \ No newline at end of file diff --git a/src/api/Transport/KSAggregateTransport/pch.h b/src/api/Transport/KSAggregateTransport/pch.h index 7aa72d432..5362975c8 100644 --- a/src/api/Transport/KSAggregateTransport/pch.h +++ b/src/api/Transport/KSAggregateTransport/pch.h @@ -113,7 +113,10 @@ namespace internal = ::WindowsMidiServicesInternal; #include "Midi2UMP2BSTransform.h" #include "Midi2UMP2BSTransform_i.c" +#include "Feature_Servicing_MIDI2VirtualPortDriversFix.h" + class CMidi2KSAggregateMidiEndpointManager; +class CMidi2KSAggregateMidiEndpointManager2; class CMidi2KSAggregateMidiInProxy; class CMidi2KSAggregateMidiOutProxy; class CMidi2KSAggregateMidiConfigurationManager; @@ -126,6 +129,7 @@ class TransportState; #include "Midi2.KSAggregateMidi.h" #include "Midi2.KSAggregateMidiBidi.h" #include "Midi2.KSAggregateMidiEndpointManager.h" +#include "Midi2.KSAggregateMidiEndpointManager2.h" #include "Midi2.KSAggregateMidiConfigurationManager.h" #include "Midi2.KSAggregateMidiPluginMetadataProvider.h" From 82fa120a87d2a7d66471b3b47b1430739a0be6d1 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Mon, 16 Feb 2026 13:57:13 -0500 Subject: [PATCH 12/18] Working on loopMIDI and similar issues --- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 356 +++++++++++++++--- .../Midi2.KSAggregateMidiEndpointManager2.h | 100 ++--- .../Midi2.KSAggregateTransport.rc | 8 +- .../Midi2.KSAggregateTransport.vcxproj | 2 +- ...Midi2.KSAggregateTransport.vcxproj.filters | 2 +- 5 files changed, 359 insertions(+), 109 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp index 3624278ca..3a5f2813d 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -53,7 +53,74 @@ KsAggregateParentDeviceDefinition2::AddPin( } +bool KsAggregateParentDeviceDefinition2::HasPendingEndpointDeviceChanges() +{ + for (auto& endpoint : Endpoints) + { + if (endpoint->HasPendingDeviceChanges) + { + return true; + } + } + + return false; +} + +_Use_decl_annotations_ +HRESULT +KsAggregateParentDeviceDefinition2::FindEndpointForFilter( + std::wstring filterDeviceId, + std::shared_ptr& foundEndpoint) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + ); + + auto cleanFilterDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(filterDeviceId); + + for (auto const& endpoint : Endpoints) + { + for (auto const& pin : endpoint->MidiPins) + { + if (internal::NormalizeEndpointInterfaceIdWStringCopy(pin->FilterDeviceId) == cleanFilterDeviceId) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Matching Endpoint found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") + ); + foundEndpoint = endpoint; + + return S_OK; + } + } + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") + ); + + foundEndpoint = nullptr; + + return E_NOTFOUND; +} @@ -160,7 +227,7 @@ typedef struct { UINT32 PinId; // KS Pin number MidiFlow PinDataFlow; // an input pin is MidiFlowIn, and from the user's perspective, a MIDI Output std::wstring FilterId; // full filter id for this pin -} PinMapEntryStagingEntry; +} PinMapEntryStagingEntry2; _Use_decl_annotations_ @@ -233,7 +300,7 @@ CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyDa RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); uint8_t currentBlockNumber{ 0 }; - std::vector pinMapEntries{ }; + std::vector pinMapEntries{ }; for (auto const& pin : masterEndpointDefinition->MidiPins) { @@ -244,7 +311,7 @@ CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyDa gtb.Number = ++currentBlockNumber; gtb.GroupCount = 1; // always a single group for aggregate MIDI 1.0 devices - PinMapEntryStagingEntry pinMapEntry{ }; + PinMapEntryStagingEntry2 pinMapEntry{ }; pinMapEntry.PinId = pin->PinNumber; pinMapEntry.FilterId = pin->FilterDeviceId; @@ -380,7 +447,7 @@ CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyDa _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager2::CreateMidiUmpEndpoint( +CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( std::shared_ptr masterEndpointDefinition, std::shared_ptr parentDevice ) @@ -609,10 +676,10 @@ CMidi2KSAggregateMidiEndpointManager2::CreateMidiUmpEndpoint( // return new device interface id masterEndpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); - auto lock = m_availableEndpointDefinitionsLock.lock(); + auto lock = m_activatedEndpointDefinitionsLock.lock(); // Add to internal endpoint manager - m_availableEndpointDefinitions.insert_or_assign( + m_activatedEndpointDefinitions.insert_or_assign( internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), masterEndpointDefinition); @@ -639,7 +706,7 @@ CMidi2KSAggregateMidiEndpointManager2::CreateMidiUmpEndpoint( _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager2::UpdateExistingMidiUmpEndpointWithFilterChanges( +CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFilterChanges( std::shared_ptr masterEndpointDefinition, std::shared_ptr parentDevice ) @@ -770,10 +837,10 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateExistingMidiUmpEndpointWithFilterCh TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) ); - auto lock = m_availableEndpointDefinitionsLock.lock(); + auto lock = m_activatedEndpointDefinitionsLock.lock(); // Add to internal endpoint manager - m_availableEndpointDefinitions.insert_or_assign( + m_activatedEndpointDefinitions.insert_or_assign( internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), masterEndpointDefinition); @@ -923,9 +990,10 @@ CMidi2KSAggregateMidiEndpointManager2::GetKSDriverSuppliedName(HANDLE hInstantia _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial( - winrt::hstring systemDevicesParentValue, + std::wstring systemDevicesParentValue, std::shared_ptr& parentDevice) { + RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); if (systemDevicesParentValue.empty()) { @@ -1021,41 +1089,61 @@ CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial( _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterDevice( - std::wstring parentDeviceInstanceId, + std::wstring filterDeviceId, std::shared_ptr& endpointDefinition ) { - for (auto const& entry : m_availableEndpointDefinitions) + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + ); + + for (auto const& entry : m_activatedEndpointDefinitions) { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->DeviceInstanceId) == - internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) + if (SUCCEEDED(entry.second->FindEndpointForFilter(filterDeviceId, endpointDefinition))) { - endpointDefinition = entry.second; + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Matching Endpoint found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + ); return S_OK; } } + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + ); + + endpointDefinition = nullptr; + return E_NOTFOUND; } _Use_decl_annotations_ HRESULT -CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice( +CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilterDevice( DeviceInformation filterDevice, - std::shared_ptr& endpointDefinition + std::shared_ptr& parentDeviceDefinition ) { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - // we require that the System.Devices.DeviceInstanceId property was requested for the passed-in filter device auto deviceInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); RETURN_HR_IF(E_FAIL, deviceInstanceId.empty()); @@ -1065,20 +1153,17 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF additionalProperties.Append(L"System.Devices.Manufacturer"); additionalProperties.Append(L"System.Devices.Parent"); - auto parentDevice = DeviceInformation::CreateFromIdAsync(deviceInstanceId, - additionalProperties, winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get(); + auto parentDevice = DeviceInformation::CreateFromIdAsync( + deviceInstanceId, + additionalProperties, + winrt::Windows::Devices::Enumeration::DeviceInformationKind::Device).get(); - // See if we already have a pending master endpoint definition for this parent device + + auto lock = m_allParentDeviceDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing - auto lock = m_pendingEndpointDefinitionsLock.lock(); // we lock to avoid having one inserted while we're processing - - auto parentInstanceIdToFind = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); - auto it = std::find_if( - m_pendingEndpointDefinitions.begin(), - m_pendingEndpointDefinitions.end(), - [&parentInstanceIdToFind](const std::shared_ptr def){return internal::NormalizeDeviceInstanceIdWStringCopy(def->ParentDeviceInstanceId) == parentInstanceIdToFind; }); + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); - if (it != m_pendingEndpointDefinitions.end()) + if (auto it = m_allParentDeviceDefinitions.find(cleanParentDeviceInstanceId); it != m_allParentDeviceDefinitions.end()) { TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1086,24 +1171,18 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found existing aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(parentInstanceIdToFind.c_str(), "parent") + TraceLoggingWideString(L"Found existing parent device.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent") ); - endpointDefinition = *it; + // we found a matching parent device. Return it. + parentDeviceDefinition = it->second; + return S_OK; } // We don't have one, create one and add, and get all the parent device information for it // we still have the map locked, so keep this code fast - auto newEndpointDefinition = std::make_shared(); - RETURN_HR_IF_NULL(E_OUTOFMEMORY, newEndpointDefinition); - - //newEndpointDefinition->ParentDeviceName = parentDevice.Name(); - //newEndpointDefinition->EndpointName = parentDevice.Name(); - //newEndpointDefinition->ParentDeviceInstanceId = parentDevice.Id(); - - //LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newEndpointDefinition->ParentDeviceInstanceId.c_str(), *newEndpointDefinition)); TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1111,26 +1190,161 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Creating new aggregate UMP endpoint definition.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(parentDevice.Id().c_str(), "parent") + TraceLoggingWideString(L"Parent device definition not already created. Creating now.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent") ); + auto newParentDeviceDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_OUTOFMEMORY, newParentDeviceDefinition); + + newParentDeviceDefinition->DeviceName = parentDevice.Name(); + newParentDeviceDefinition->DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); + + LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newParentDeviceDefinition->DeviceInstanceId, newParentDeviceDefinition)); + // only some vendor drivers provide an actual manufacturer // and all the in-box drivers just provide the Generic USB Audio string // TODO: Is "Generic USB Audio" a string that is localized? If so, this // code will not have the intended effect outside of en-US - //auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); - //if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") - //{ - // newEndpointDefinition->ManufacturerName = manufacturer; - //} + auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); + if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") + { + newParentDeviceDefinition->ManufacturerName = manufacturer; + } + + m_allParentDeviceDefinitions[newParentDeviceDefinition->DeviceInstanceId] = newParentDeviceDefinition; + parentDeviceDefinition = newParentDeviceDefinition; + + return S_OK; +} + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindCurrentMaxEndpointIndexForParentDevice( + std::shared_ptr parentDeviceDefinition, + uint16_t& currentMaxIndex) +{ + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceDefinition->DeviceInstanceId); + + int32_t maxIndex{ -1 }; + bool found{ false }; + + auto activatedLock = m_activatedEndpointDefinitionsLock.lock(); + auto pendingLock = m_pendingEndpointDefinitionsLock.lock(); + + // look through all pending and activated endpoints and find the max. + // If the max is 0 or greater, return it and set S_OK. + // if no endpoints found, return E_NOTFOUND so the calling code + // knows that the 0 is not in use + + for (auto const& ep : m_activatedEndpointDefinitions) + { + if (ep.second->ParentDeviceInstanceId == cleanParentDeviceInstanceId) + { + maxIndex++; + found = true; + } + } + + for (auto const& ep : m_pendingEndpointDefinitions) + { + if (ep->ParentDeviceInstanceId == cleanParentDeviceInstanceId) + { + maxIndex++; + found = true; + } + } + - //// default hash is the device id. - //std::hash hasher; - //std::wstring hash; - //hash = std::to_wstring(hasher(newEndpointDefinition->ParentDeviceInstanceId)); + if (found) + { + currentMaxIndex = maxIndex; + return S_OK; + } + else + { + return E_NOTFOUND; + } - //newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; +} + + +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice( + DeviceInformation filterDevice, + std::shared_ptr& endpointDefinition +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); + + std::shared_ptr parentDeviceDefinition{ nullptr }; + + // this function locks the parent device list for the duration of the call + RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice( + filterDevice, + parentDeviceDefinition)); + + RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); + + // at this point, we have a complete parent device definition, so we can find or create endpoints for it + + + + // TODO: See if we already have an endpoint with space for the number of groups we're going to add + + + + // create a new endpoint + auto newEndpointDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); + + + // We need to ensure each endpoint has a unique id. They can't all use the ParentDeviceInstanceId as the + // instance id because now some devices will have multiple endpoints. Instead, we need to add a suffix to + // that. We need this to be deterministic and not just a random GUID/number, so that device ids have a + // chance to match up when next enumerated after a restart or connect/disconnect. + + auto parentLock = m_allParentDeviceDefinitionsLock.lock(); + + uint16_t endpointIndexForThisParent{ 0 }; + if (SUCCEEDED(FindCurrentMaxEndpointIndexForParentDevice(parentDeviceDefinition, endpointIndexForThisParent))) + { + // increment the number here + endpointIndexForThisParent++; + } + + + newEndpointDefinition->EndpointIndexForThisParentDevice = endpointIndexForThisParent; + + + // default hash is the device id. + std::hash hasher; + std::wstring hash; + hash = std::to_wstring(hasher(parentDeviceDefinition->DeviceInstanceId)); + + if (endpointIndexForThisParent == 0) + { + newEndpointDefinition->EndpointName = parentDeviceDefinition->DeviceName; + newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; + } + else + { + // pad the string with "0" characters to the left of the number, up to 3 places total. + // we +1 so the second one is _002 and not _001 + newEndpointDefinition->EndpointDeviceInstanceId = std::format(L"{0}{1}_{2:0>3}", TRANSPORT_INSTANCE_ID_PREFIX, hash, endpointIndexForThisParent + 1); + + // add the name disambiguator to the endpoint. We +1 to the index for the same reasons as above. + newEndpointDefinition->EndpointName = std::format(L"{0} ({1})", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); + //newEndpointDefinition->EndpointName = std::format(L"{1} - {0}", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); + } TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1583,6 +1797,26 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( TraceLoggingWideString(filterDevice.Id().c_str(), "added interface") ); + + // 1. Get the list of pins that are the right category for us to try to activate + // 2. Loop through and build final list of all MIDI 1.0 source and destination pins + // 3. Do we already have an activated endpoint for this device? + // 3.1 If we do, then see if it has room for these pins. + // 3.1.1 If it has room, then add these pins to the endpoint + // 3.1.2 If it doesn't have room, then build a new endpoint for this device + // and add that endpoint to the pending endpoints list + // 3.2 If we do not have an activated endpoint, see if we have a pending endpoint + // + + + + + + + + + + std::wstring transportCode(TRANSPORT_CODE); // Wrapper opens the handle internally. @@ -1902,7 +2136,7 @@ winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEn { criteria.Normalize(); - for (auto const& def : m_availableEndpointDefinitions) + for (auto const& def : m_activatedEndpointDefinitions) { WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h index 35cbc3eae..b818b40ca 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h @@ -21,7 +21,6 @@ using namespace winrt::Windows::Devices::Enumeration; struct KsAggregateEndpointMidiPinDefinition2 { - //std::wstring KSDriverSuppliedName; std::wstring FilterDeviceId; // this is also the value needed by WinMM for DRV_QUERYDEVICEINTERFACE std::wstring FilterName; @@ -33,12 +32,12 @@ struct KsAggregateEndpointMidiPinDefinition2 uint8_t GroupIndex{ 0 }; uint8_t PortIndexWithinThisFilterAndDirection{ 0 }; // not always the same as the group index. Example: MOTU Express 128 with separate filter for each in/out pair - - // internal::Midi1PortNaming::Midi1PortNameEntry PortNames; }; struct KsAggregateEndpointDefinition2 { + std::wstring ParentDeviceInstanceId{}; + std::wstring EndpointDeviceId{}; std::wstring EndpointName{}; @@ -47,6 +46,8 @@ struct KsAggregateEndpointDefinition2 std::vector> MidiPins{ }; WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{ }; + + uint16_t EndpointIndexForThisParentDevice{ 0 }; }; @@ -57,25 +58,14 @@ class KsAggregateParentDeviceDefinition2 std::wstring DeviceInstanceId{}; std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming + std::wstring NameDisambiguatorPrefix{}; // for when there are multiple of the same device attached + + uint16_t VID{ 0 }; // USB-only uint16_t PID{ 0 }; // USB-only std::wstring SerialNumber{}; std::wstring ManufacturerName{}; - - std::vector> Endpoints{ }; // most devices will have just one endpoint, but virtual can have > 1 - - - // This will add a pin, and create new endpoints as needed, assign the group index, etc. - // it will also update the name table given the info we have - HRESULT AddPin(_In_ std::shared_ptr pin); - - //HRESULT RemoveFilter(_In_ std::wstring filterId); - - - -private: - }; @@ -93,23 +83,11 @@ class CMidi2KSAggregateMidiEndpointManager2 : winrt::hstring FindMatchingInstantiatedEndpoint(_In_ WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria); private: + DWORD m_individualInterfaceEnumTimeoutMS{ DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; + wil::com_ptr_nothrow m_midiDeviceManager; wil::com_ptr_nothrow m_midiProtocolManager; - wil::critical_section m_availableEndpointDefinitionsLock; - wil::critical_section m_pendingEndpointDefinitionsLock; - - HRESULT ParseParentIdIntoVidPidSerial( - _In_ winrt::hstring systemDevicesParentValue, - _In_ std::shared_ptr& parentDevice); - - HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName); - HRESULT GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow); - - STDMETHOD(CreateMidiUmpEndpoint)( - _In_ std::shared_ptr masterEndpointDefinition, - _In_ std::shared_ptr parentDevice); - HRESULT OnFilterDeviceInterfaceAdded(_In_ DeviceWatcher, _In_ DeviceInformation); HRESULT OnFilterDeviceInterfaceRemoved(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); HRESULT OnFilterDeviceInterfaceUpdated(_In_ DeviceWatcher, _In_ DeviceInformationUpdate); @@ -117,25 +95,57 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT OnEnumerationCompleted(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); HRESULT OnDeviceWatcherStopped(_In_ DeviceWatcher, _In_ winrt::Windows::Foundation::IInspectable); + + // key is parent device instance id. These don't get "activated" so we keep a single list + std::map> m_allParentDeviceDefinitions; + wil::critical_section m_allParentDeviceDefinitionsLock; + + // these are all endpoints which have not yet been activated + std::vector> m_pendingEndpointDefinitions; + wil::critical_section m_pendingEndpointDefinitionsLock; + // key is parent device instance id - std::map> m_availableEndpointDefinitions; - std::vector> m_pendingEndpointDefinitions; + std::map> m_activatedEndpointDefinitions; + wil::critical_section m_activatedEndpointDefinitionsLock; + + + + bool ActiveKSAEndpointForDeviceExists( + _In_ std::wstring deviceInstanceId); + + HRESULT ParseParentIdIntoVidPidSerial( + _In_ std::wstring systemDevicesParentValue, + _In_ std::shared_ptr& parentDevice); HRESULT FindActivatedEndpointDefinitionForFilterDevice( - _In_ std::wstring parentDeviceInstanceId, + _In_ std::wstring filterDeviceId, _In_ std::shared_ptr&); + + HRESULT FindOrCreateParentDeviceDefinitionForFilterDevice( + DeviceInformation filterDevice, + std::shared_ptr& parentDeviceDefinition); + HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( _In_ DeviceInformation, _In_ std::shared_ptr&); + + HRESULT FindCurrentMaxEndpointIndexForParentDevice( + _In_ std::shared_ptr parentDeviceDefinition, + _In_ uint16_t& currentMaxIndex); + + + HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName); + HRESULT GetPinDataFlow(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ KSPIN_DATAFLOW& dataFlow); + HRESULT GetMidi1FilterPins( _In_ DeviceInformation, _In_ std::vector&); - bool ActiveKSAEndpointForDeviceExists( - _In_ std::wstring deviceInstanceId); + HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); + //HRESULT IncrementAndGetNextGroupIndex( // _In_ std::shared_ptr definition, @@ -156,14 +166,22 @@ class CMidi2KSAggregateMidiEndpointManager2 : _In_ std::shared_ptr masterEndpointDefinition, _In_ std::shared_ptr customProperties); - wil::unique_event_nothrow m_endpointCreationThreadWakeup; - std::jthread m_endpointCreationThread; - void EndpointCreationThreadWorker(_In_ std::stop_token token); - HRESULT UpdateExistingMidiUmpEndpointWithFilterChanges( + // these two functions actually update the software devices in Windows + + HRESULT DeviceCreateMidiUmpEndpoint( _In_ std::shared_ptr masterEndpointDefinition, _In_ std::shared_ptr parentDevice); + HRESULT DeviceUpdateExistingMidiUmpEndpointWithFilterChanges( + _In_ std::shared_ptr masterEndpointDefinition, + _In_ std::shared_ptr parentDevice); + + + wil::unique_event_nothrow m_endpointCreationThreadWakeup; + std::jthread m_endpointCreationThread; + void EndpointCreationThreadWorker(_In_ std::stop_token token); + DeviceWatcher m_watcher{0}; winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::Added_revoker m_DeviceAdded; @@ -173,7 +191,5 @@ class CMidi2KSAggregateMidiEndpointManager2 : winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher::EnumerationCompleted_revoker m_DeviceEnumerationCompleted; wil::unique_event m_EnumerationCompleted{wil::EventOptions::None}; - HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); - DWORD m_individualInterfaceEnumTimeoutMS { DEFAULT_KSA_INTERFACE_ENUM_TIMEOUT_MS }; }; diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc index 232b8b0cc..70183e097 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.rc @@ -53,8 +53,8 @@ END // VS_VERSION_INFO VERSIONINFO - FILEVERSION 1,0,15,0 - PRODUCTVERSION 1,0,15,0 + FILEVERSION 1,0,16,0 + PRODUCTVERSION 1,0,16,0 FILEFLAGSMASK VS_FFI_FILEFLAGSMASK #ifdef _DEBUG FILEFLAGS VS_FF_DEBUG @@ -71,12 +71,12 @@ BEGIN BEGIN VALUE "CompanyName", "Microsoft Corporation" VALUE "FileDescription", "Windows MIDI Services KSA Transport Plugin" - VALUE "FileVersion", "1.0.15.0" + VALUE "FileVersion", "1.0.16.0" VALUE "LegalCopyright", "Copyright (c) Microsoft Corporation. All rights reserved." VALUE "InternalName", "Midi2.KSAggregateTransport.dll" VALUE "OriginalFilename", "Midi2.KSAggregateTransport.dll" VALUE "ProductName", "Microsoft Windows MIDI Services" - VALUE "ProductVersion", "1.0.15.0" + VALUE "ProductVersion", "1.0.16.0" VALUE "OLESelfRegister", "" END END diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj index 7cbd5b341..145bc00c6 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj @@ -381,7 +381,7 @@ - + diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters index cabf18a58..8dcaaf342 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.vcxproj.filters @@ -106,7 +106,7 @@ Header Files - + Header Files From 83626cd6a518a93779c61bde9ed085b457834431 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Mon, 16 Feb 2026 20:00:00 -0500 Subject: [PATCH 13/18] Fix missing Feature include directory --- src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj b/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj index f205a17a3..377f6c9bb 100644 --- a/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj +++ b/src/api/Drivers/USBMIDI2/Driver/USBMidi2.vcxproj @@ -79,24 +79,28 @@ USBMidi2 $(SolutionDir)VSFiles\$(Platform)\$(Configuration)\ $(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\ + $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc DbgengKernelDebugger USBMidi2 $(SolutionDir)VSFiles\$(Platform)\$(Configuration)\ $(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\ + $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc DbgengKernelDebugger USBMidi2 $(SolutionDir)VSFiles\$(Platform)\$(Configuration)\ $(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\ + $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc DbgengKernelDebugger USBMidi2 $(SolutionDir)VSFiles\$(Platform)\$(Configuration)\ $(SolutionDir)VSFiles\intermediate\$(ProjectName)\$(Platform)\$(Configuration)\ + $(IncludePath);$(KMDF_INC_PATH)$(KMDF_VER_PATH);$(SolutionDir)\inc false From 7e27bb624942bf1a370984a9faf7bb07de75190a Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Mon, 16 Feb 2026 20:00:43 -0500 Subject: [PATCH 14/18] More changes to support loopMIDI etc. --- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 761 ++++++++---------- .../Midi2.KSAggregateMidiEndpointManager2.h | 40 +- 2 files changed, 364 insertions(+), 437 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp index 3a5f2813d..1f669c4ec 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -13,6 +13,7 @@ #include // for the string stream in parsing of VID/PID/Serial from parent id #include // for getline for string parsing of VID/PID/Serial from parent id +#include "Feature_Servicing_MIDI2FilterCreations.h" using namespace wil; using namespace winrt::Windows::Devices::Enumeration; @@ -22,114 +23,6 @@ using namespace Microsoft::WRL; using namespace Microsoft::WRL::Wrappers; #define INITIAL_ENUMERATION_TIMEOUT_MS 10000 - - - - -_Use_decl_annotations_ -HRESULT -KsAggregateParentDeviceDefinition2::AddPin( - std::shared_ptr pin) -{ - UNREFERENCED_PARAMETER(pin); - - // TODO - - // do we have any endpoints with this filter? - // if so, check to see if it has space for more pins - - // if no existing endpoint with this filter, then check to see if we have any - // endpoints at all we can add this to - - // if no endpoints at all, or if all other endpoints are filled, create - // a new endpoint - - - // Add this pin to the endpoint - - - - return S_OK; -} - - -bool KsAggregateParentDeviceDefinition2::HasPendingEndpointDeviceChanges() -{ - for (auto& endpoint : Endpoints) - { - if (endpoint->HasPendingDeviceChanges) - { - return true; - } - } - - return false; -} - -_Use_decl_annotations_ -HRESULT -KsAggregateParentDeviceDefinition2::FindEndpointForFilter( - std::wstring filterDeviceId, - std::shared_ptr& foundEndpoint) -{ - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") - ); - - auto cleanFilterDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(filterDeviceId); - - for (auto const& endpoint : Endpoints) - { - for (auto const& pin : endpoint->MidiPins) - { - if (internal::NormalizeEndpointInterfaceIdWStringCopy(pin->FilterDeviceId) == cleanFilterDeviceId) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Matching Endpoint found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") - ); - - foundEndpoint = endpoint; - - return S_OK; - } - } - } - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") - ); - - foundEndpoint = nullptr; - - return E_NOTFOUND; -} - - - - - - - - - _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::Initialize( @@ -448,12 +341,14 @@ CMidi2KSAggregateMidiEndpointManager2::BuildPinsAndGroupTerminalBlocksPropertyDa _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( - std::shared_ptr masterEndpointDefinition, - std::shared_ptr parentDevice + std::shared_ptr endpointDefinition ) { - RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); - RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); + RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition); + + std::shared_ptr parentDevice { nullptr }; + RETURN_IF_FAILED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDevice)); + RETURN_HR_IF_NULL(E_UNEXPECTED, parentDevice); TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -462,22 +357,24 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "endpoint name"), + TraceLoggingWideString(parentDevice->DeviceName.c_str(), "device name") ); DEVPROP_BOOLEAN devPropTrue = DEVPROP_TRUE; - // we require at least one valid pin - RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + // we require at least one valid pin, and no more than 32 total pins (16 in, 16 out, max) + RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() < 1); + RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() > 32); std::vector interfaceDevProperties; MIDIENDPOINTCOMMONPROPERTIES commonProperties{}; commonProperties.TransportId = TRANSPORT_LAYER_GUID; commonProperties.EndpointDeviceType = MidiEndpointDeviceType_Normal; - commonProperties.FriendlyName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.FriendlyName = endpointDefinition->EndpointName.c_str(); commonProperties.TransportCode = TRANSPORT_CODE; - commonProperties.EndpointName = masterEndpointDefinition->EndpointName.c_str(); + commonProperties.EndpointName = endpointDefinition->EndpointName.c_str(); commonProperties.EndpointDescription = nullptr; commonProperties.CustomEndpointName = nullptr; commonProperties.CustomEndpointDescription = nullptr; @@ -498,7 +395,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( std::vector nameTablePropertyData; RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( - masterEndpointDefinition, + endpointDefinition, pinMapPropertyData, groupTerminalBlocks)); @@ -521,6 +418,8 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( else { // write empty data + interfaceDevProperties.push_back({ { PKEY_MIDI_GroupTerminalBlocks, DEVPROP_STORE_SYSTEM, nullptr }, + DEVPROP_TYPE_EMPTY, 0, nullptr }); } @@ -528,16 +427,16 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( // =============================================================================== WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; - matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId); matchCriteria.UsbVendorId = parentDevice->VID; matchCriteria.UsbProductId = parentDevice->PID; matchCriteria.UsbSerialNumber = parentDevice->SerialNumber; - matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + matchCriteria.TransportSuppliedEndpointName = endpointDefinition->EndpointName; auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); // rebuild the name table, using the custom properties if present - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties)); std::wstring customName{ }; std::wstring customDescription{ }; @@ -550,7 +449,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") ); @@ -579,14 +478,14 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) ); } // Write Name table property, folding in the custom names we discovered earlier // =============================================================================================== - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties)); + endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); // Write USB VID/PID Data @@ -624,9 +523,9 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( SW_DEVICE_CREATE_INFO createInfo{ }; createInfo.cbSize = sizeof(createInfo); - createInfo.pszInstanceId = masterEndpointDefinition->EndpointDeviceInstanceId.c_str(); + createInfo.pszInstanceId = endpointDefinition->EndpointDeviceInstanceId.c_str(); createInfo.CapabilityFlags = SWDeviceCapabilitiesNone; - createInfo.pszDeviceDescription = masterEndpointDefinition->EndpointName.c_str(); + createInfo.pszDeviceDescription = endpointDefinition->EndpointName.c_str(); // Call the device manager and finish the creation @@ -640,7 +539,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Activating endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name") ); // set to true if we only want to create UMP endpoints @@ -669,19 +568,19 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"), TraceLoggingWideString(newDeviceInterfaceId.get(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) ); // return new device interface id - masterEndpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); + endpointDefinition->EndpointDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(std::wstring{ newDeviceInterfaceId.get() }); auto lock = m_activatedEndpointDefinitionsLock.lock(); // Add to internal endpoint manager m_activatedEndpointDefinitions.insert_or_assign( internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), - masterEndpointDefinition); + endpointDefinition); return swdCreationResult; } @@ -694,7 +593,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( TraceLoggingLevel(WINEVENT_LEVEL_ERROR), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"), TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD) ); @@ -707,12 +606,15 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceCreateMidiUmpEndpoint( _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFilterChanges( - std::shared_ptr masterEndpointDefinition, - std::shared_ptr parentDevice + std::shared_ptr endpointDefinition ) { - RETURN_HR_IF_NULL(E_INVALIDARG, masterEndpointDefinition); - RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); + RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition); + + std::shared_ptr parentDevice{ nullptr }; + RETURN_IF_FAILED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDevice)); + RETURN_HR_IF_NULL(E_UNEXPECTED, parentDevice); + TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -721,11 +623,11 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name") + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name") ); // we require at least one valid pin - RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition->MidiPins.size() < 1); + RETURN_HR_IF(E_INVALIDARG, endpointDefinition->MidiPins.size() < 1); std::vector interfaceDevProperties{ }; @@ -736,7 +638,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi // update the pin map to have all the existing pins // plus the new pins. Update Group Terminal Blocks at the same time. RETURN_IF_FAILED(BuildPinsAndGroupTerminalBlocksPropertyData( - masterEndpointDefinition, + endpointDefinition, pinMapPropertyData, groupTerminalBlocks)); @@ -767,11 +669,11 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi // =============================================================================== WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria matchCriteria{}; - matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(masterEndpointDefinition->EndpointDeviceInstanceId); + matchCriteria.DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->EndpointDeviceInstanceId); matchCriteria.UsbVendorId = parentDevice->VID; matchCriteria.UsbProductId = parentDevice->PID; matchCriteria.UsbSerialNumber = parentDevice->SerialNumber; - matchCriteria.TransportSuppliedEndpointName = masterEndpointDefinition->EndpointName; + matchCriteria.TransportSuppliedEndpointName = endpointDefinition->EndpointName; auto customProperties = TransportState::Current().GetConfigurationManager()->CustomPropertiesCache()->GetProperties(matchCriteria); @@ -786,7 +688,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Found custom properties cached for this endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), + TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD), TraceLoggingUInt32(static_cast(customProperties->Midi1Sources.size()), "MIDI 1 Source count"), TraceLoggingUInt32(static_cast(customProperties->Midi1Destinations.size()), "MIDI 1 Destination count") ); @@ -803,23 +705,23 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"No cached custom properties for this endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) + TraceLoggingWideString(endpointDefinition->EndpointDeviceInstanceId.c_str(), MIDI_TRACE_EVENT_DEVICE_INSTANCE_ID_FIELD) ); } // store the property data for the name table - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); // Write Name table property, folding in the custom names we discovered earlier // =============================================================================================== - RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(masterEndpointDefinition, customProperties)); - masterEndpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); + RETURN_IF_FAILED(UpdateNameTableWithCustomProperties(endpointDefinition, customProperties)); + endpointDefinition->EndpointNameTable.WriteProperties(interfaceDevProperties); HRESULT updateResult{}; LOG_IF_FAILED(updateResult = m_midiDeviceManager->UpdateEndpointProperties( - masterEndpointDefinition->EndpointDeviceId.c_str(), + endpointDefinition->EndpointDeviceId.c_str(), static_cast(interfaceDevProperties.size()), interfaceDevProperties.data() )); @@ -834,7 +736,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Aggregate UMP endpoint updated with new filter", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) + TraceLoggingWideString(endpointDefinition->EndpointDeviceId.c_str(), MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD) ); auto lock = m_activatedEndpointDefinitionsLock.lock(); @@ -842,7 +744,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi // Add to internal endpoint manager m_activatedEndpointDefinitions.insert_or_assign( internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice->DeviceInstanceId), - masterEndpointDefinition); + endpointDefinition); } else @@ -854,7 +756,7 @@ CMidi2KSAggregateMidiEndpointManager2::DeviceUpdateExistingMidiUmpEndpointWithFi TraceLoggingLevel(WINEVENT_LEVEL_ERROR), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Aggregate UMP endpoint update failed", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(masterEndpointDefinition->EndpointName.c_str(), "name"), + TraceLoggingWideString(endpointDefinition->EndpointName.c_str(), "name"), TraceLoggingHResult(updateResult, MIDI_TRACE_EVENT_HRESULT_FIELD) ); } @@ -1103,9 +1005,11 @@ CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterD TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") ); - for (auto const& entry : m_activatedEndpointDefinitions) + auto cleanFilterDeviceId = internal::NormalizeEndpointInterfaceIdWStringCopy(filterDeviceId); + + for (auto const& endpoint : m_activatedEndpointDefinitions) { - if (SUCCEEDED(entry.second->FindEndpointForFilter(filterDeviceId, endpointDefinition))) + for (auto const& pin: endpoint.second->MidiPins) { TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1114,10 +1018,14 @@ CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterD TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"Matching Endpoint found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") ); - return S_OK; + if (internal::NormalizeEndpointInterfaceIdWStringCopy(pin->FilterDeviceId) == cleanFilterDeviceId) + { + endpointDefinition = endpoint.second; + return S_OK; + } } } @@ -1128,7 +1036,7 @@ CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterD TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDeviceId.c_str(), "filter device id") + TraceLoggingWideString(cleanFilterDeviceId.c_str(), "filter device id") ); endpointDefinition = nullptr; @@ -1137,10 +1045,29 @@ CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterD } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindExistingParentDeviceDefinitionForEndpoint( + std::shared_ptr endpointDefinition, + std::shared_ptr& parentDeviceDefinition +) +{ + if (auto parent = m_allParentDeviceDefinitions.find(endpointDefinition->ParentDeviceInstanceId); parent != m_allParentDeviceDefinitions.end()) + { + parentDeviceDefinition = parent->second; + + return S_OK; + } + + return E_NOTFOUND; +} + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilterDevice( DeviceInformation filterDevice, + KsHandleWrapper& filterDeviceHandleWrapper, std::shared_ptr& parentDeviceDefinition ) { @@ -1212,17 +1139,51 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilt newParentDeviceDefinition->ManufacturerName = manufacturer; } + // Driver-supplied name. This is needed for WinMM backwards compatibility + std::wstring driverSuppliedName{}; + + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail, + // adding unnecessary noise to the error logs + filterDeviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); + + newParentDeviceDefinition->DriverSuppliedDeviceName = driverSuppliedName; + + + // Do we need to disambiguate this parent because another of the same device already exists? + + uint32_t currentMaxIndex { 0 }; + bool otherParentsWithSameNameExist{ false }; + + for (auto const& existingParent : m_allParentDeviceDefinitions) + { + if (existingParent.second->DeviceName == newParentDeviceDefinition->DeviceName) + { + currentMaxIndex = max(currentMaxIndex, existingParent.second->IndexOfDevicesWithThisSameName); + otherParentsWithSameNameExist = true; + } + } + + if (otherParentsWithSameNameExist) + { + parentDeviceDefinition->IndexOfDevicesWithThisSameName = currentMaxIndex + 1; + } + + m_allParentDeviceDefinitions[newParentDeviceDefinition->DeviceInstanceId] = newParentDeviceDefinition; parentDeviceDefinition = newParentDeviceDefinition; return S_OK; } + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindCurrentMaxEndpointIndexForParentDevice( std::shared_ptr parentDeviceDefinition, - uint16_t& currentMaxIndex) + uint32_t& currentMaxIndex) { auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceDefinition->DeviceInstanceId); @@ -1258,7 +1219,7 @@ CMidi2KSAggregateMidiEndpointManager2::FindCurrentMaxEndpointIndexForParentDevic if (found) { - currentMaxIndex = maxIndex; + currentMaxIndex = static_cast(maxIndex); return S_OK; } else @@ -1273,6 +1234,7 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice( DeviceInformation filterDevice, + KsHandleWrapper& filterDeviceHandleWrapper, std::shared_ptr& endpointDefinition ) { @@ -1282,7 +1244,8 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + TraceLoggingWideString(L"Enter.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") ); std::shared_ptr parentDeviceDefinition{ nullptr }; @@ -1290,15 +1253,20 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF // this function locks the parent device list for the duration of the call RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice( filterDevice, - parentDeviceDefinition)); + filterDeviceHandleWrapper, + parentDeviceDefinition + )); RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); - // at this point, we have a complete parent device definition, so we can find or create endpoints for it + + // TODO: See if we already have an endpoint with space for the number of groups we're going to add + + + - // TODO: See if we already have an endpoint with space for the number of groups we're going to add @@ -1306,7 +1274,6 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF auto newEndpointDefinition = std::make_shared(); RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); - // We need to ensure each endpoint has a unique id. They can't all use the ParentDeviceInstanceId as the // instance id because now some devices will have multiple endpoints. Instead, we need to add a suffix to // that. We need this to be deterministic and not just a random GUID/number, so that device ids have a @@ -1314,17 +1281,15 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF auto parentLock = m_allParentDeviceDefinitionsLock.lock(); - uint16_t endpointIndexForThisParent{ 0 }; + uint32_t endpointIndexForThisParent{ 0 }; if (SUCCEEDED(FindCurrentMaxEndpointIndexForParentDevice(parentDeviceDefinition, endpointIndexForThisParent))) { // increment the number here endpointIndexForThisParent++; } - newEndpointDefinition->EndpointIndexForThisParentDevice = endpointIndexForThisParent; - // default hash is the device id. std::hash hasher; std::wstring hash; @@ -1362,26 +1327,28 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF return S_OK; } -//_Use_decl_annotations_ -//HRESULT -//CMidi2KSAggregateMidiEndpointManager::IncrementAndGetNextGroupIndex( -// std::shared_ptr definition, -// MidiFlow dataFlowFromUserPerspective, -// uint8_t& groupIndex) -//{ -// if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn) -// { -// definition->CurrentHighestMidiSourceGroupIndex++; -// groupIndex = definition->CurrentHighestMidiSourceGroupIndex; -// } -// else -// { -// definition->CurrentHighestMidiDestinationGroupIndex++; -// groupIndex = definition->CurrentHighestMidiDestinationGroupIndex; -// } -// -// return S_OK; -//} +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::IncrementAndGetNextGroupIndex( + std::shared_ptr definition, + MidiFlow dataFlowFromUserPerspective, + uint8_t& groupIndex) +{ + // the structure is initialized with -1 for the current group index in each direction + + if (dataFlowFromUserPerspective == MidiFlow::MidiFlowIn) + { + definition->CurrentHighestMidiSourceGroupIndex++; + groupIndex = definition->CurrentHighestMidiSourceGroupIndex; + } + else + { + definition->CurrentHighestMidiDestinationGroupIndex++; + groupIndex = definition->CurrentHighestMidiDestinationGroupIndex; + } + + return S_OK; +} #define MAX_THREAD_WORKER_WAIT_TIME_MS 20000 @@ -1474,7 +1441,7 @@ void CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker( m_pendingEndpointDefinitions.erase(m_pendingEndpointDefinitions.begin()); // create the endpoint - LOG_IF_FAILED(CreateMidiUmpEndpoint(ep)); + LOG_IF_FAILED(DeviceCreateMidiUmpEndpoint(ep)); } TraceLoggingWrite( @@ -1486,34 +1453,6 @@ void CMidi2KSAggregateMidiEndpointManager2::EndpointCreationThreadWorker( TraceLoggingWideString(L"EndpointCreationWorker: Processed all pending endpoint definitions.", MIDI_TRACE_EVENT_MESSAGE_FIELD) ); } - -#ifdef _DEBUG - else - { - if (m_pendingEndpointDefinitions.size() == 0) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but no work to do. Pending count == 0.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - } - else - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"EndpointCreationWorker: Woken up, but thread is no longer signaled", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - } - } -#endif } } @@ -1533,10 +1472,12 @@ _Use_decl_annotations_ bool CMidi2KSAggregateMidiEndpointManager2::ActiveKSAEndpointForDeviceExists( _In_ std::wstring parentDeviceInstanceId) { - for (auto const& entry : m_availableEndpointDefinitions) + + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str()); + + for (auto const& entry : m_activatedEndpointDefinitions) { - if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->DeviceInstanceId) == - internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str())) + if (internal::NormalizeDeviceInstanceIdWStringCopy(entry.second->ParentDeviceInstanceId) == cleanParentDeviceInstanceId) { return true; } @@ -1546,11 +1487,39 @@ bool CMidi2KSAggregateMidiEndpointManager2::ActiveKSAEndpointForDeviceExists( } +bool ShouldSkipOpeningKsPin(_In_ KsHandleWrapper& deviceHandleWrapper, _In_ UINT pinIndex) +{ + if (Feature_Servicing_MIDI2FilterCreations::IsEnabled()) + { + std::unique_ptr dataRanges; + ULONG dataRangesSize{ 0 }; + + // skip this pin if for some reason data ranges aren't valid + if (FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return RetrieveDataRanges(h, pinIndex, (PKSMULTIPLE_ITEM*)&dataRanges, &dataRangesSize); + }))) + { + return true; + } + + // Skip this pin if it supports cyclic ump, or if it doesn't support bytestream + if (SUCCEEDED(DataRangeSupportsTransport(dataRanges.get(), MidiTransport_CyclicUMP)) || + FAILED(DataRangeSupportsTransport(dataRanges.get(), MidiTransport_StandardByteStream))) + { + return true; + } + } + + return false; +} + + + _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( DeviceInformation filterDevice, - std::vector& pinListToAddTo + std::vector>& pinListToAddTo ) { // Wrapper opens the handle internally. @@ -1588,123 +1557,90 @@ CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( continue; } - // Duplicate the handle to safely pass it to another component or store it. - wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle()); - RETURN_IF_NULL_ALLOC(handleDupe); - - // we try to open UMP only so we understand the device - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Checking for UMP pin. This will fallback error fail for non-UMP devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - - KsHandleWrapper m_PinHandleWrapperUmp(filterDevice.Id().c_str(), pinIndex, MidiTransport_CyclicUMP, handleDupe.get()); - if (SUCCEEDED(m_PinHandleWrapperUmp.Open())) + if (ShouldSkipOpeningKsPin(deviceHandleWrapper, pinIndex)) { - // this is a UMP pin. The KS transport will handle it, so we skip it here. - // In the future, we may want to bail on the first UMP pin we find. - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Found UMP/MIDI2 pin. Skipping for this transport.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); - continue; } + // Duplicate the handle to safely pass it to another component or store it. + wil::unique_handle handleDupe(deviceHandleWrapper.GetHandle()); + RETURN_IF_NULL_ALLOC(handleDupe); - // try to open as a MIDI 1 bytestream pin - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Checking for MIDI 1 pin. This will fallback error fail for non-MIDI devices.", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); + KsHandleWrapper pinHandleWrapper( + filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); - KsHandleWrapper pinHandleWrapperMidi1(filterDevice.Id().c_str(), pinIndex, MidiTransport_StandardByteStream, handleDupe.get()); - if (SUCCEEDED(pinHandleWrapperMidi1.Open())) + if (SUCCEEDED(pinHandleWrapper.Open())) { - // this is a MIDI 1.0 byte format pin, so let's process it - KsAggregateEndpointMidiPinDefinition pinDefinition{ }; + auto pinDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_POINTER, pinDefinition); - pinDefinition.PinNumber = pinIndex; - pinDefinition.FilterDeviceId = std::wstring{ filterDevice.Id() }; - pinDefinition.FilterName = std::wstring{ filterDevice.Name() }; + //pinDefinition.KSDriverSuppliedName = driverSuppliedName; + pinDefinition->PinNumber = pinIndex; + pinDefinition->FilterDeviceId = std::wstring{ filterDevice.Id() }; + pinDefinition->FilterName = std::wstring{ filterDevice.Name() }; // find the name of the pin (supplied by iJack, and if not available, generated by the stack) std::wstring pinName{ }; + HRESULT pinNameHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { return GetPinName(h, pinIndex, pinName); }); if (SUCCEEDED(pinNameHr)) { - pinDefinition.PinName = pinName; - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Pin has name", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(pinDefinition.PinName.c_str(), "pin name") - ); + pinDefinition->PinName = pinName; } - // get the data flow so we know if this is a MIDI Input (Source) or a MIDI Output (Destination) + // get the data flow so we know if this is a MIDI Input or a MIDI Output KSPIN_DATAFLOW dataFlow = (KSPIN_DATAFLOW)0; - RETURN_IF_FAILED(deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + + HRESULT dataFlowHr = deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { return GetPinDataFlow(h, pinIndex, dataFlow); - })); + }); - if (dataFlow == KSPIN_DATAFLOW_IN) + if (SUCCEEDED(dataFlowHr)) { - // MIDI Out (input to device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowIn; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); + if (dataFlow == KSPIN_DATAFLOW_IN) + { + // MIDI Out (input to device) + pinDefinition->PinDataFlow = MidiFlow::MidiFlowIn; + pinDefinition->DataFlowFromUserPerspective = MidiFlow::MidiFlowOut; // opposite + pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); - midiOutputPinIndexForThisFilter++; + midiOutputPinIndexForThisFilter++; + } + else if (dataFlow == KSPIN_DATAFLOW_OUT) + { + // MIDI In (output from device) + pinDefinition->PinDataFlow = MidiFlow::MidiFlowOut; + pinDefinition->DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite + pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); + + midiInputPinIndexForThisFilter++; + } + + pinListToAddTo.push_back(pinDefinition); } - else if (dataFlow == KSPIN_DATAFLOW_OUT) + else { - // MIDI In (output from device) - pinDefinition.PinDataFlow = MidiFlow::MidiFlowOut; - pinDefinition.DataFlowFromUserPerspective = MidiFlow::MidiFlowIn; // opposite - pinDefinition.PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); - - midiInputPinIndexForThisFilter++; + // this is a failure condition. Move on to next pin + LOG_IF_FAILED(dataFlowHr); + continue; } - - pinListToAddTo.push_back(pinDefinition); - - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"MIDI 1.0 pin added", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") - ); } } + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"MIDI 1.0 pins enumerated", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), + TraceLoggingUInt32(static_cast(pinListToAddTo.size()), "Total size of pin list including new pins.") + ); + return S_OK; } @@ -1731,7 +1667,11 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( // Figure out the group index for the pin. This needs the context of the // entire device. Failure to get the group index is fatal -// RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); + RETURN_IF_FAILED(IncrementAndGetNextGroupIndex( + endpointDefinition, + pinDefinition->DataFlowFromUserPerspective, + pinDefinition->GroupIndex + )); TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), @@ -1746,6 +1686,9 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") ); + // guard against errors resulting in invalid group indexes + RETURN_HR_IF(E_UNEXPECTED, !IS_VALID_GROUP_INDEX(pinDefinition->GroupIndex)); + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update // Build the name table entry for this individual pin @@ -1764,22 +1707,6 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( } - - -//HRESULT -//PopulatePinKSDataFormats(HANDLE filterHandle/*, Some_vector_of_pin_format_structs*/) -//{ -// //Try this, it should be a fairly easy thing to add to your change. -// // retrieve the : -// //KSPROPSETID_Pin, -// // KSPROPERTY_PIN_DATARANGES, -// -// // limit to pins with(pKsDataFormat->MajorFormat == KSDATAFORMAT_TYPE_MUSIC) -// // -// // Retrieval is going to follow the same ksmultipleitemp pattern as KSPROPERTY_MIDI2_GROUP_TERMINAL_BLOCKS -//} - - _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( @@ -1809,26 +1736,16 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( // - - - - - - - - std::wstring transportCode(TRANSPORT_CODE); // Wrapper opens the handle internally. KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); RETURN_IF_FAILED(deviceHandleWrapper.Open()); - std::shared_ptr endpointDefinition{ nullptr }; - // get all the MIDI 1 pins. These are only partially processed because some things // like group index and naming require the full context of all filters/pins on the - // parent device. - std::vector pinList{ }; + // parent device. We want to get these before we try creating parents or endpoints + std::vector> pinList{ }; RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList)); if (pinList.size() == 0) @@ -1845,24 +1762,61 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( return S_OK; } + else if (pinList.size() > 32) + { + // we don't support more than 32 pins, 16 in, 16 out, on a single endpoint. + // We also don't support splitting a filter across more than one endpoint + // so we can only enumerate the first 16 in each direction + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_ERROR, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Too many MIDI pins for this filter. Maximum of 16 in and 16 out allowed per KS filter.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id") + ); + + RETURN_IF_FAILED(E_FAIL); + } + + + + + // We have MIDI 1.0 pins to process, so we'll need to find or create a parent device + // and also find or create an endpoint under that parent, which has sufficient room + // for these pins. + + + // =================================================================== + // Find or create a parent device definition auto parentInstanceId = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceInstanceId", filterDevice, L""); RETURN_HR_IF(E_FAIL, parentInstanceId.empty()); - // Driver-supplied name. This is needed for WinMM backwards compatibility - std::wstring driverSuppliedName{}; + std::shared_ptr parentDeviceDefinition{ nullptr }; + + RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice( + filterDevice, + deviceHandleWrapper, + parentDeviceDefinition + )); - // Using lamba function to prevent handle from dissapearing when being used. - // we don't log the HRESULT because this is not critical and will often fail, - // adding unnecessary noise to the error logs - deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - }); + // =================================================================== + // Find or create the endpoint + + + std::shared_ptr endpointDefinition{ nullptr }; // check to see if we already have an *activated* endpoint for this filter if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str())) { + // TODO: If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one + + + TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_VERBOSE, @@ -1877,14 +1831,14 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( std::shared_ptr existingActivatedEndpointDefinition { nullptr }; // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindActivatedMasterEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); // add our new pins into the existing endpoint definition existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), driverSuppliedName, existingActivatedEndpointDefinition)); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), parentDeviceDefinition->DriverSuppliedDeviceName, existingActivatedEndpointDefinition)); - RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); + RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); return S_OK; } @@ -1902,7 +1856,6 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( ); } - // if the endpointDefinition is null, that means we haven't found an existing // activated endpoint definition we need to use, and so we proceed to check // for an existing pending endpoint definition. If found, it's used. If not @@ -1911,74 +1864,23 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( if (endpointDefinition == nullptr) { // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreatePendingMasterEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, deviceHandleWrapper, endpointDefinition)); RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); - // add our new pins into the existing endpoint definition - endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - pinList.clear(); // just make sure we don't use this one, accidentally - } - -#ifdef _DEBUG - if (!driverSuppliedName.empty()) - { - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Driver-supplied name found", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingWideString(driverSuppliedName.c_str(), "driver-supplied name") - ); - } -#endif - // At this point, we need to have *all* the pins for the endpoint, not just this filter - for (auto& pinDefinition : endpointDefinition->MidiPins) - { - if (internal::NormalizeDeviceInstanceIdWStringCopy(pinDefinition.FilterDeviceId) != - internal::NormalizeDeviceInstanceIdWStringCopy(filterDevice.Id().c_str())) - { - // only process the pins for this filter interface. We don't want to - // change anything that has already been built. But we do need the - // context of all pins when getting the group index. - continue; - } + // TODO: If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one - // Figure out the group index for the pin. This needs the context of the - // entire device. Failure to get the group index is fatal - RETURN_IF_FAILED(IncrementAndGetNextGroupIndex(endpointDefinition, pinDefinition.DataFlowFromUserPerspective, pinDefinition.GroupIndex)); - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_VERBOSE, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Assigned Group Index to pin", MIDI_TRACE_EVENT_MESSAGE_FIELD), - TraceLoggingWideString(filterDevice.Id().c_str(), "filter device id"), - TraceLoggingUInt8(pinDefinition.GroupIndex, "group index"), - TraceLoggingWideString(pinDefinition.DataFlowFromUserPerspective == - MidiFlow::MidiFlowOut ? L"MidiFlowOut" : L"MidiFlowIn", "data flow from user's perspective") - ); - std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update - // Build the name table entry for this individual pin - endpointDefinition->EndpointNameTable.PopulateEntryForMidi1DeviceUsingMidi1Driver( - pinDefinition.GroupIndex, - pinDefinition.DataFlowFromUserPerspective, - customName, - driverSuppliedName, - pinDefinition.FilterName, - pinDefinition.PinName, - pinDefinition.PortIndexWithinThisFilterAndDirection - ); + // add our new pins into the existing endpoint definition + endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + pinList.clear(); // just make sure we don't use this one, accidentally } + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), parentDeviceDefinition->DriverSuppliedDeviceName, endpointDefinition)); + // we have an endpoint definition m_endpointCreationThreadWakeup.SetEvent(); @@ -2008,16 +1910,14 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( // find an active device with this filter - std::shared_ptr endpointDefinition{ nullptr }; - - + std::shared_ptr endpointDefinition{ nullptr }; - for (auto& endpointListIterator : m_availableEndpointDefinitions) + for (auto& endpointListIterator : m_activatedEndpointDefinitions) { // check pins for this filter for (auto& pin: endpointListIterator.second->MidiPins) { - if (internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId) + if (internal::NormalizeDeviceInstanceIdWStringCopy(pin->FilterDeviceId) == removedFilterDeviceId) { endpointDefinition = endpointListIterator.second; break; @@ -2035,7 +1935,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( auto foundIt = std::find_if( endpointDefinition->MidiPins.begin(), endpointDefinition->MidiPins.end(), - [&removedFilterDeviceId](KsAggregateEndpointMidiPinDefinition& pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin.FilterDeviceId) == removedFilterDeviceId; } + [&removedFilterDeviceId](std::shared_ptr pin) { return internal::NormalizeDeviceInstanceIdWStringCopy(pin->FilterDeviceId) == removedFilterDeviceId; } ); if (foundIt != endpointDefinition->MidiPins.end()) @@ -2058,13 +1958,24 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( // TODO: Need to cache the name from the driver/registry so we don't have to do a lookup here. - // update remaining pins in existing endpoint definition - RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, L"", endpointDefinition)); - RETURN_IF_FAILED(UpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); + std::shared_ptr parentDeviceDefinition{ nullptr }; + + if (SUCCEEDED(FindExistingParentDeviceDefinitionForEndpoint(endpointDefinition, parentDeviceDefinition))) + { + RETURN_HR_IF_NULL(E_UNEXPECTED, parentDeviceDefinition); + + // update remaining pins in existing endpoint definition + RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, parentDeviceDefinition->DriverSuppliedDeviceName, endpointDefinition)); + RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); + } + else + { + RETURN_IF_FAILED(E_NOTFOUND); + } } else { - auto lock = m_availableEndpointDefinitionsLock.lock(); + auto lock = m_activatedEndpointDefinitionsLock.lock(); // notify the device manager using the InstanceId for this midi device RETURN_IF_FAILED(m_midiDeviceManager->RemoveEndpoint( @@ -2072,7 +1983,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( // remove the endpoint from the list - m_availableEndpointDefinitions.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId)); + m_activatedEndpointDefinitions.erase(internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId)); } } @@ -2132,7 +2043,8 @@ CMidi2KSAggregateMidiEndpointManager2::OnEnumerationCompleted(DeviceWatcher watc _Use_decl_annotations_ -winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEndpoint(WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria) +winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEndpoint( + WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria& criteria) { criteria.Normalize(); @@ -2140,13 +2052,20 @@ winrt::hstring CMidi2KSAggregateMidiEndpointManager2::FindMatchingInstantiatedEn { WindowsMidiServicesPluginConfigurationLib::MidiEndpointMatchCriteria available{}; - available.DeviceInstanceId = def.second->DeviceInstanceId; + available.DeviceInstanceId = def.second->EndpointDeviceInstanceId; available.EndpointDeviceId = def.second->EndpointDeviceId; - available.UsbVendorId = def.second->VID; - available.UsbProductId = def.second->PID; - available.UsbSerialNumber = def.second->SerialNumber; available.TransportSuppliedEndpointName = def.second->EndpointName; - available.DeviceManufacturerName = def.second->ManufacturerName; + + std::shared_ptr parent { nullptr }; + + if (SUCCEEDED(FindExistingParentDeviceDefinitionForEndpoint(def.second, parent))) + { + available.UsbVendorId = parent->VID; + available.UsbProductId = parent->PID; + available.UsbSerialNumber = parent->SerialNumber; + available.DeviceManufacturerName = parent->ManufacturerName; + } + if (available.Matches(criteria)) { diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h index b818b40ca..98c265e08 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h @@ -47,7 +47,12 @@ struct KsAggregateEndpointDefinition2 WindowsMidiServicesNamingLib::MidiEndpointNameTable EndpointNameTable{ }; - uint16_t EndpointIndexForThisParentDevice{ 0 }; + uint32_t EndpointIndexForThisParentDevice{ 0 }; + + + int8_t CurrentHighestMidiSourceGroupIndex{ -1 }; + int8_t CurrentHighestMidiDestinationGroupIndex{ -1 }; + }; @@ -58,7 +63,7 @@ class KsAggregateParentDeviceDefinition2 std::wstring DeviceInstanceId{}; std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming - std::wstring NameDisambiguatorPrefix{}; // for when there are multiple of the same device attached + uint32_t IndexOfDevicesWithThisSameName{ 0 }; // for when there are multiple of the same device uint16_t VID{ 0 }; // USB-only @@ -114,27 +119,32 @@ class CMidi2KSAggregateMidiEndpointManager2 : _In_ std::wstring deviceInstanceId); HRESULT ParseParentIdIntoVidPidSerial( - _In_ std::wstring systemDevicesParentValue, - _In_ std::shared_ptr& parentDevice); + _In_ std::wstring systemDevicesParentValue, + _In_ std::shared_ptr& parentDevice); HRESULT FindActivatedEndpointDefinitionForFilterDevice( _In_ std::wstring filterDeviceId, _In_ std::shared_ptr&); + HRESULT FindExistingParentDeviceDefinitionForEndpoint( + _In_ std::shared_ptr endpointDefinition, + _In_ std::shared_ptr& parentDeviceDefinition); HRESULT FindOrCreateParentDeviceDefinitionForFilterDevice( - DeviceInformation filterDevice, - std::shared_ptr& parentDeviceDefinition); + _In_ DeviceInformation filterDevice, + _In_ KsHandleWrapper& filterDeviceWrapper, + _In_ std::shared_ptr& parentDeviceDefinition); HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( _In_ DeviceInformation, + _In_ KsHandleWrapper& filterDeviceHandleWrapper, _In_ std::shared_ptr&); HRESULT FindCurrentMaxEndpointIndexForParentDevice( _In_ std::shared_ptr parentDeviceDefinition, - _In_ uint16_t& currentMaxIndex); + _In_ uint32_t& currentMaxIndex); HRESULT GetPinName(_In_ HANDLE const hFilter, _In_ UINT const pinIndex, _Inout_ std::wstring& pinName); @@ -142,15 +152,15 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT GetMidi1FilterPins( _In_ DeviceInformation, - _In_ std::vector&); + _In_ std::vector>&); HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); - //HRESULT IncrementAndGetNextGroupIndex( - // _In_ std::shared_ptr definition, - // _In_ MidiFlow dataFlowFromUserPerspective, - // _In_ uint8_t& groupIndex); + HRESULT IncrementAndGetNextGroupIndex( + _In_ std::shared_ptr definition, + _In_ MidiFlow dataFlowFromUserPerspective, + _In_ uint8_t& groupIndex); HRESULT UpdateNewPinDefinitions( _In_ std::wstring filterDeviceid, @@ -170,12 +180,10 @@ class CMidi2KSAggregateMidiEndpointManager2 : // these two functions actually update the software devices in Windows HRESULT DeviceCreateMidiUmpEndpoint( - _In_ std::shared_ptr masterEndpointDefinition, - _In_ std::shared_ptr parentDevice); + _In_ std::shared_ptr masterEndpointDefinition); HRESULT DeviceUpdateExistingMidiUmpEndpointWithFilterChanges( - _In_ std::shared_ptr masterEndpointDefinition, - _In_ std::shared_ptr parentDevice); + _In_ std::shared_ptr masterEndpointDefinition); wil::unique_event_nothrow m_endpointCreationThreadWakeup; From b3501877ebbc4f57296db4155751f43ed6c2f9ed Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Mon, 16 Feb 2026 20:00:49 -0500 Subject: [PATCH 15/18] Build artifacts --- build/staging/version/BundleInfo.wxi | 4 ++-- build/staging/version/WindowsMidiServicesVersion.cs | 12 ++++++------ build/staging/version/WindowsMidiServicesVersion.h | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/build/staging/version/BundleInfo.wxi b/build/staging/version/BundleInfo.wxi index c352c6d40..0f37e7c4d 100644 --- a/build/staging/version/BundleInfo.wxi +++ b/build/staging/version/BundleInfo.wxi @@ -1,5 +1,5 @@ - - + + diff --git a/build/staging/version/WindowsMidiServicesVersion.cs b/build/staging/version/WindowsMidiServicesVersion.cs index c211cdaf8..f385244e6 100644 --- a/build/staging/version/WindowsMidiServicesVersion.cs +++ b/build/staging/version/WindowsMidiServicesVersion.cs @@ -9,15 +9,15 @@ public static class MidiNuGetBuildInformation { public const bool IsPreview = true; public const string Source = "GitHub Preview"; - public const string BuildDate = "2026-02-07"; + public const string BuildDate = "2026-02-16"; public const string Name = "Service Preview 14"; - public const string BuildFullVersion = "1.0.15-preview.14.74"; + public const string BuildFullVersion = "1.0.15-preview.14.79"; public const ushort VersionMajor = 1; public const ushort VersionMinor = 0; public const ushort VersionPatch = 15; - public const ushort VersionBuildNumber = 74; - public const string Preview = "preview.14.74"; - public const string AssemblyFullVersion = "1.0.15.74"; - public const string FileFullVersion = "1.0.15.74"; + public const ushort VersionBuildNumber = 79; + public const string Preview = "preview.14.79"; + public const string AssemblyFullVersion = "1.0.15.79"; + public const string FileFullVersion = "1.0.15.79"; } } diff --git a/build/staging/version/WindowsMidiServicesVersion.h b/build/staging/version/WindowsMidiServicesVersion.h index fb9913615..77ede8148 100644 --- a/build/staging/version/WindowsMidiServicesVersion.h +++ b/build/staging/version/WindowsMidiServicesVersion.h @@ -7,14 +7,14 @@ #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_IS_PREVIEW true #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_SOURCE L"GitHub Preview" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-02-07" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_DATE L"2026-02-16" #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_NAME L"Service Preview 14" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.15-preview.14.74" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FULL L"1.0.15-preview.14.79" #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MAJOR 1 #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_MINOR 0 #define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_PATCH 15 -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 74 -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"preview.14.74" -#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.15.74" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_BUILD_NUMBER 79 +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_PREVIEW L"preview.14.79" +#define WINDOWS_MIDI_SERVICES_NUGET_BUILD_VERSION_FILE L"1.0.15.79" #endif From 05246ed395c1c36b608112b5acac647f5b8d1878 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Mon, 16 Feb 2026 23:05:25 -0500 Subject: [PATCH 16/18] Working on KSA loopmidi issue Adding interfaces works again. Need to finish support for > 16/32 groups --- ...i2.KSAggregateMidiConfigurationManager.cpp | 20 +- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 351 +++++++++++++----- .../Midi2.KSAggregateMidiEndpointManager2.h | 27 +- .../Midi2.KSAggregateTransport.cpp | 19 +- .../KSAggregateTransport/TransportState.h | 6 +- 5 files changed, 315 insertions(+), 108 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp index 12f393c9a..927ad0e8c 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiConfigurationManager.cpp @@ -258,12 +258,26 @@ CMidi2KSAggregateMidiConfigurationManager::UpdateConfiguration( // Resolve the EndpointDeviceId in case we matched on something else winrt::hstring matchingEndpointDeviceId{}; - auto em = TransportState::Current().GetEndpointManager(); - if (em != nullptr) + + + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + { + auto em = TransportState::Current().GetEndpointManager2(); + if (em != nullptr) + { + matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria); + } + } + else { - matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria); + auto em = TransportState::Current().GetEndpointManager(); + if (em != nullptr) + { + matchingEndpointDeviceId = em->FindMatchingInstantiatedEndpoint(*matchCriteria); + } } + // process all the custom props like Name, Description, Image, etc. LOG_IF_FAILED(ProcessCustomProperties( matchingEndpointDeviceId, diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp index 1f669c4ec..f38b97f3f 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -893,7 +893,7 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial( std::wstring systemDevicesParentValue, - std::shared_ptr& parentDevice) + std::shared_ptr parentDevice) { RETURN_HR_IF_NULL(E_INVALIDARG, parentDevice); @@ -987,6 +987,58 @@ CMidi2KSAggregateMidiEndpointManager2::ParseParentIdIntoVidPidSerial( return S_OK; } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindPendingEndpointDefinitionForParentDevice( + std::wstring parentDeviceInstanceId, + std::shared_ptr& endpointDefinition) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id") + ); + + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId); + + for (auto const& endpoint : m_pendingEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint->ParentDeviceInstanceId) == cleanParentDeviceInstanceId) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id") + ); + + endpointDefinition = endpoint; + return S_OK; + } + } + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent deviec instance id") + ); + + endpointDefinition = nullptr; + + return E_NOTFOUND; +} + _Use_decl_annotations_ HRESULT @@ -1044,6 +1096,69 @@ CMidi2KSAggregateMidiEndpointManager2::FindActivatedEndpointDefinitionForFilterD return E_NOTFOUND; } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindAllActivatedEndpointDefinitionsForParentDevice( + std::wstring parentDeviceInstanceId, + std::vector>& endpointDefinitions +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent device instance id") + ); + + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId); + + bool found { false }; + + for (auto const& endpoint : m_activatedEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint.second->ParentDeviceInstanceId) == cleanParentDeviceInstanceId) + { + endpointDefinitions.push_back(endpoint.second); + found = true; + } + } + + + if (found) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"One or more matching endpoints found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id") + ); + + return S_OK; + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No matches found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id") + ); + + return E_NOTFOUND; + } + +} + + _Use_decl_annotations_ HRESULT @@ -1052,13 +1167,47 @@ CMidi2KSAggregateMidiEndpointManager2::FindExistingParentDeviceDefinitionForEndp std::shared_ptr& parentDeviceDefinition ) { - if (auto parent = m_allParentDeviceDefinitions.find(endpointDefinition->ParentDeviceInstanceId); parent != m_allParentDeviceDefinitions.end()) + RETURN_HR_IF_NULL(E_INVALIDARG, endpointDefinition); + + auto cleanDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(endpointDefinition->ParentDeviceInstanceId); + + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Looking for matching parent", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id") + ); + + if (auto parent = m_allParentDeviceDefinitions.find(cleanDeviceInstanceId); parent != m_allParentDeviceDefinitions.end()) { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Match found.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id") + ); + parentDeviceDefinition = parent->second; return S_OK; } + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No match found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanDeviceInstanceId.c_str(), "parent device instance id") + ); + return E_NOTFOUND; } @@ -1067,7 +1216,6 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilterDevice( DeviceInformation filterDevice, - KsHandleWrapper& filterDeviceHandleWrapper, std::shared_ptr& parentDeviceDefinition ) { @@ -1125,32 +1273,23 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilt RETURN_HR_IF_NULL(E_OUTOFMEMORY, newParentDeviceDefinition); newParentDeviceDefinition->DeviceName = parentDevice.Name(); - newParentDeviceDefinition->DeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDevice.Id().c_str()); + newParentDeviceDefinition->DeviceInstanceId = cleanParentDeviceInstanceId; LOG_IF_FAILED(ParseParentIdIntoVidPidSerial(newParentDeviceDefinition->DeviceInstanceId, newParentDeviceDefinition)); // only some vendor drivers provide an actual manufacturer // and all the in-box drivers just provide the Generic USB Audio string - // TODO: Is "Generic USB Audio" a string that is localized? If so, this - // code will not have the intended effect outside of en-US + // TODO: Is "Generic USB Audio" a string that is localized? If so, this code will not have the intended effect outside of en-US auto manufacturer = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.DeviceManufacturer", parentDevice, L""); + auto manufacturer2 = internal::SafeGetSwdPropertyFromDeviceInformation(L"System.Devices.Manufacturer", parentDevice, L""); if (!manufacturer.empty() && manufacturer != L"(Generic USB Audio)" && manufacturer != L"Microsoft") { newParentDeviceDefinition->ManufacturerName = manufacturer; } - - // Driver-supplied name. This is needed for WinMM backwards compatibility - std::wstring driverSuppliedName{}; - - // Using lamba function to prevent handle from dissapearing when being used. - // we don't log the HRESULT because this is not critical and will often fail, - // adding unnecessary noise to the error logs - filterDeviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { - return GetKSDriverSuppliedName(h, driverSuppliedName); - }); - - newParentDeviceDefinition->DriverSuppliedDeviceName = driverSuppliedName; - + else if (!manufacturer2.empty() && manufacturer2 != L"(Generic USB Audio)" && manufacturer2 != L"Microsoft") + { + newParentDeviceDefinition->ManufacturerName = manufacturer2; + } // Do we need to disambiguate this parent because another of the same device already exists? @@ -1175,6 +1314,16 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreateParentDeviceDefinitionForFilt m_allParentDeviceDefinitions[newParentDeviceDefinition->DeviceInstanceId] = newParentDeviceDefinition; parentDeviceDefinition = newParentDeviceDefinition; + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Parent device definition added.", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(newParentDeviceDefinition->DeviceInstanceId.c_str(), "key (device instance id)") + ); + return S_OK; } @@ -1234,7 +1383,6 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForFilterDevice( DeviceInformation filterDevice, - KsHandleWrapper& filterDeviceHandleWrapper, std::shared_ptr& endpointDefinition ) { @@ -1253,78 +1401,85 @@ CMidi2KSAggregateMidiEndpointManager2::FindOrCreatePendingEndpointDefinitionForF // this function locks the parent device list for the duration of the call RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice( filterDevice, - filterDeviceHandleWrapper, parentDeviceDefinition )); RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); - // TODO: See if we already have an endpoint with space for the number of groups we're going to add - + // See if we already have an endpoint with space for the number of groups we're going to add + std::shared_ptr existingEndpointDefinition { nullptr }; + if (SUCCEEDED(FindPendingEndpointDefinitionForParentDevice(parentDeviceDefinition->DeviceInstanceId, existingEndpointDefinition))) + { + RETURN_HR_IF_NULL(E_UNEXPECTED, existingEndpointDefinition); + // TODO: Need to check to see if this endpoint has enough space for the new pins + endpointDefinition = existingEndpointDefinition; + return S_OK; + } + else + { + // create a new endpoint + auto newEndpointDefinition = std::make_shared(); + RETURN_HR_IF_NULL(E_POINTER, newEndpointDefinition); + // We need to ensure each endpoint has a unique id. They can't all use the ParentDeviceInstanceId as the + // instance id because now some devices will have multiple endpoints. Instead, we need to add a suffix to + // that. We need this to be deterministic and not just a random GUID/number, so that device ids have a + // chance to match up when next enumerated after a restart or connect/disconnect. + auto parentLock = m_allParentDeviceDefinitionsLock.lock(); - // create a new endpoint - auto newEndpointDefinition = std::make_shared(); - RETURN_HR_IF_NULL(E_POINTER, parentDeviceDefinition); + uint32_t endpointIndexForThisParent{ 0 }; + if (SUCCEEDED(FindCurrentMaxEndpointIndexForParentDevice(parentDeviceDefinition, endpointIndexForThisParent))) + { + // increment the number here + endpointIndexForThisParent++; + } - // We need to ensure each endpoint has a unique id. They can't all use the ParentDeviceInstanceId as the - // instance id because now some devices will have multiple endpoints. Instead, we need to add a suffix to - // that. We need this to be deterministic and not just a random GUID/number, so that device ids have a - // chance to match up when next enumerated after a restart or connect/disconnect. + newEndpointDefinition->ParentDeviceInstanceId = parentDeviceDefinition->DeviceInstanceId; + newEndpointDefinition->EndpointIndexForThisParentDevice = endpointIndexForThisParent; - auto parentLock = m_allParentDeviceDefinitionsLock.lock(); + // default hash is the device id. + std::hash hasher; + std::wstring hash; + hash = std::to_wstring(hasher(parentDeviceDefinition->DeviceInstanceId)); - uint32_t endpointIndexForThisParent{ 0 }; - if (SUCCEEDED(FindCurrentMaxEndpointIndexForParentDevice(parentDeviceDefinition, endpointIndexForThisParent))) - { - // increment the number here - endpointIndexForThisParent++; - } + if (endpointIndexForThisParent == 0) + { + newEndpointDefinition->EndpointName = parentDeviceDefinition->DeviceName; + newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; + } + else + { + // pad the string with "0" characters to the left of the number, up to 3 places total. + // we +1 so the second one is _002 and not _001 + newEndpointDefinition->EndpointDeviceInstanceId = std::format(L"{0}{1}_{2:0>3}", TRANSPORT_INSTANCE_ID_PREFIX, hash, endpointIndexForThisParent + 1); - newEndpointDefinition->EndpointIndexForThisParentDevice = endpointIndexForThisParent; + // add the name disambiguator to the endpoint. We +1 to the index for the same reasons as above. + newEndpointDefinition->EndpointName = std::format(L"{0} ({1})", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); + //newEndpointDefinition->EndpointName = std::format(L"{1} - {0}", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); + } - // default hash is the device id. - std::hash hasher; - std::wstring hash; - hash = std::to_wstring(hasher(parentDeviceDefinition->DeviceInstanceId)); + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_INFO, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) + ); - if (endpointIndexForThisParent == 0) - { - newEndpointDefinition->EndpointName = parentDeviceDefinition->DeviceName; - newEndpointDefinition->EndpointDeviceInstanceId = TRANSPORT_INSTANCE_ID_PREFIX + hash; - } - else - { - // pad the string with "0" characters to the left of the number, up to 3 places total. - // we +1 so the second one is _002 and not _001 - newEndpointDefinition->EndpointDeviceInstanceId = std::format(L"{0}{1}_{2:0>3}", TRANSPORT_INSTANCE_ID_PREFIX, hash, endpointIndexForThisParent + 1); + m_pendingEndpointDefinitions.push_back(newEndpointDefinition); + endpointDefinition = newEndpointDefinition; - // add the name disambiguator to the endpoint. We +1 to the index for the same reasons as above. - newEndpointDefinition->EndpointName = std::format(L"{0} ({1})", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); - //newEndpointDefinition->EndpointName = std::format(L"{1} - {0}", parentDeviceDefinition->DeviceName, endpointIndexForThisParent + 1); + return S_OK; } - TraceLoggingWrite( - MidiKSAggregateTransportTelemetryProvider::Provider(), - MIDI_TRACE_EVENT_INFO, - TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingPointer(this, "this"), - TraceLoggingWideString(L"Adding pending aggregate UMP endpoint.", MIDI_TRACE_EVENT_MESSAGE_FIELD) - ); - - m_pendingEndpointDefinitions.push_back(newEndpointDefinition); - endpointDefinition = newEndpointDefinition; - - - return S_OK; } _Use_decl_annotations_ @@ -1472,7 +1627,6 @@ _Use_decl_annotations_ bool CMidi2KSAggregateMidiEndpointManager2::ActiveKSAEndpointForDeviceExists( _In_ std::wstring parentDeviceInstanceId) { - auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId.c_str()); for (auto const& entry : m_activatedEndpointDefinitions) @@ -1526,6 +1680,18 @@ CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); RETURN_IF_FAILED(deviceHandleWrapper.Open()); + + + // Driver-supplied name. This is needed for WinMM backwards compatibility + std::wstring driverSuppliedName{}; + + // Using lamba function to prevent handle from dissapearing when being used. + // we don't log the HRESULT because this is not critical and will often fail, + // adding unnecessary noise to the error logs + deviceHandleWrapper.Execute([&](HANDLE h) -> HRESULT { + return GetKSDriverSuppliedName(h, driverSuppliedName); + }); + // ============================================================================================= // Go through all the enumerated pins, looking for a MIDI 1.0 pin @@ -1598,6 +1764,11 @@ CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( return GetPinDataFlow(h, pinIndex, dataFlow); }); + + + pinDefinition->DriverSuppliedName = driverSuppliedName; + + if (SUCCEEDED(dataFlowHr)) { if (dataFlow == KSPIN_DATAFLOW_IN) @@ -1650,7 +1821,6 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( std::wstring filterDeviceid, - std::wstring driverSuppliedName, std::shared_ptr endpointDefinition) { // At this point, we need to have *all* the pins for the endpoint, not just this filter @@ -1689,6 +1859,7 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( // guard against errors resulting in invalid group indexes RETURN_HR_IF(E_UNEXPECTED, !IS_VALID_GROUP_INDEX(pinDefinition->GroupIndex)); + std::wstring customName = L""; // This is blank here. It gets folded in later during endpoint creation/update // Build the name table entry for this individual pin @@ -1696,7 +1867,7 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( pinDefinition->GroupIndex, pinDefinition->DataFlowFromUserPerspective, customName, - driverSuppliedName, + pinDefinition->DriverSuppliedName, pinDefinition->FilterName, pinDefinition->PinName, pinDefinition->PortIndexWithinThisFilterAndDirection @@ -1799,7 +1970,6 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( RETURN_IF_FAILED(FindOrCreateParentDeviceDefinitionForFilterDevice( filterDevice, - deviceHandleWrapper, parentDeviceDefinition )); @@ -1807,16 +1977,11 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( // =================================================================== // Find or create the endpoint - std::shared_ptr endpointDefinition{ nullptr }; // check to see if we already have an *activated* endpoint for this filter if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str())) { - // TODO: If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one - - - TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_VERBOSE, @@ -1830,13 +1995,29 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( std::shared_ptr existingActivatedEndpointDefinition { nullptr }; - // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); - RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); + // Get all existing endpoint definitions for this parent device + // If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one + + std::vector> foundEndpoints {}; + + if (SUCCEEDED(FindAllActivatedEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints))) + { + // find an endpoint with room for another interface with pins. + // We're going by the highest group index + + + // TEMP + existingActivatedEndpointDefinition = foundEndpoints[0]; + } + + + //// first MIDI 1 pin we're processing for this interface + //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); // add our new pins into the existing endpoint definition existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), parentDeviceDefinition->DriverSuppliedDeviceName, existingActivatedEndpointDefinition)); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingActivatedEndpointDefinition)); RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); @@ -1864,7 +2045,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( if (endpointDefinition == nullptr) { // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, deviceHandleWrapper, endpointDefinition)); + RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); @@ -1879,7 +2060,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( pinList.clear(); // just make sure we don't use this one, accidentally } - RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), parentDeviceDefinition->DriverSuppliedDeviceName, endpointDefinition)); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), endpointDefinition)); // we have an endpoint definition m_endpointCreationThreadWakeup.SetEvent(); @@ -1965,7 +2146,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceRemoved( RETURN_HR_IF_NULL(E_UNEXPECTED, parentDeviceDefinition); // update remaining pins in existing endpoint definition - RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, parentDeviceDefinition->DriverSuppliedDeviceName, endpointDefinition)); + RETURN_IF_FAILED(UpdateNewPinDefinitions(removedFilterDeviceId, endpointDefinition)); RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(endpointDefinition)); } else diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h index 98c265e08..79cc2cdac 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h @@ -24,7 +24,9 @@ struct KsAggregateEndpointMidiPinDefinition2 std::wstring FilterDeviceId; // this is also the value needed by WinMM for DRV_QUERYDEVICEINTERFACE std::wstring FilterName; - ULONG PinNumber; + std::wstring DriverSuppliedName{}; // value from registry. Required for WinMM classic naming. This was at parent level efore, but loopMIDI and similar register per-interface + + ULONG PinNumber{ 0 }; std::wstring PinName; MidiFlow PinDataFlow; @@ -61,7 +63,6 @@ class KsAggregateParentDeviceDefinition2 public: std::wstring DeviceName{}; std::wstring DeviceInstanceId{}; - std::wstring DriverSuppliedDeviceName{}; // value from registry. Required for WinMM classic naming uint32_t IndexOfDevicesWithThisSameName{ 0 }; // for when there are multiple of the same device @@ -120,26 +121,33 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT ParseParentIdIntoVidPidSerial( _In_ std::wstring systemDevicesParentValue, - _In_ std::shared_ptr& parentDevice); + _In_ std::shared_ptr parentDevice); HRESULT FindActivatedEndpointDefinitionForFilterDevice( _In_ std::wstring filterDeviceId, - _In_ std::shared_ptr&); + _Inout_ std::shared_ptr&); + + HRESULT FindAllActivatedEndpointDefinitionsForParentDevice( + _In_ std::wstring parentDeviceInstanceId, + _Inout_ std::vector>& endpointDefinitions + ); + + HRESULT FindPendingEndpointDefinitionForParentDevice( + _In_ std::wstring parentDeviceInstanceId, + _Inout_ std::shared_ptr&); HRESULT FindExistingParentDeviceDefinitionForEndpoint( _In_ std::shared_ptr endpointDefinition, - _In_ std::shared_ptr& parentDeviceDefinition); + _Inout_ std::shared_ptr& parentDeviceDefinition); HRESULT FindOrCreateParentDeviceDefinitionForFilterDevice( _In_ DeviceInformation filterDevice, - _In_ KsHandleWrapper& filterDeviceWrapper, - _In_ std::shared_ptr& parentDeviceDefinition); + _Inout_ std::shared_ptr& parentDeviceDefinition); HRESULT FindOrCreatePendingEndpointDefinitionForFilterDevice( _In_ DeviceInformation, - _In_ KsHandleWrapper& filterDeviceHandleWrapper, - _In_ std::shared_ptr&); + _Inout_ std::shared_ptr&); HRESULT FindCurrentMaxEndpointIndexForParentDevice( @@ -164,7 +172,6 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT UpdateNewPinDefinitions( _In_ std::wstring filterDeviceid, - _In_ std::wstring driverSuppliedName, _In_ std::shared_ptr endpointDefinition); HRESULT BuildPinsAndGroupTerminalBlocksPropertyData( diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp index e8cd71c79..3f9ada055 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateTransport.cpp @@ -44,19 +44,24 @@ CMidi2KSAggregateTransport::Activate( TraceLoggingWideString(L"IMidiEndpointManager", MIDI_TRACE_EVENT_INTERFACE_FIELD) ); - // check to see if this is the first time we're creating the endpoint manager. If so, create it. - if (TransportState::Current().GetEndpointManager() == nullptr) - { - TransportState::Current().ConstructEndpointManager(); - } - - if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { + // check to see if this is the first time we're creating the endpoint manager. If so, create it. + if (TransportState::Current().GetEndpointManager2() == nullptr) + { + TransportState::Current().ConstructEndpointManager(); + } + RETURN_IF_FAILED(TransportState::Current().GetEndpointManager2()->QueryInterface(iid, activatedInterface)); } else { + // check to see if this is the first time we're creating the endpoint manager. If so, create it. + if (TransportState::Current().GetEndpointManager() == nullptr) + { + TransportState::Current().ConstructEndpointManager(); + } + RETURN_IF_FAILED(TransportState::Current().GetEndpointManager()->QueryInterface(iid, activatedInterface)); } diff --git a/src/api/Transport/KSAggregateTransport/TransportState.h b/src/api/Transport/KSAggregateTransport/TransportState.h index 57726be7b..3b7928de7 100644 --- a/src/api/Transport/KSAggregateTransport/TransportState.h +++ b/src/api/Transport/KSAggregateTransport/TransportState.h @@ -23,13 +23,13 @@ class TransportState wil::com_ptr GetEndpointManager() { - if (!Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) + if (Feature_Servicing_MIDI2VirtualPortDriversFix::IsEnabled()) { - return m_endpointManager; + return nullptr; } else { - return nullptr; + return m_endpointManager; } } From 54be453edc18d68213bca1382b04c2e7c18fe0a5 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Tue, 17 Feb 2026 00:11:10 -0500 Subject: [PATCH 17/18] Working on counts > 16 groups in either direction --- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 237 ++++++++++++++---- .../Midi2.KSAggregateMidiEndpointManager2.h | 12 +- 2 files changed, 199 insertions(+), 50 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp index f38b97f3f..937de3f83 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -1158,6 +1158,68 @@ CMidi2KSAggregateMidiEndpointManager2::FindAllActivatedEndpointDefinitionsForPar } +_Use_decl_annotations_ +HRESULT +CMidi2KSAggregateMidiEndpointManager2::FindAllPendingEndpointDefinitionsForParentDevice( + std::wstring parentDeviceInstanceId, + std::vector>& endpointDefinitions +) +{ + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(parentDeviceInstanceId.c_str(), "parent device instance id") + ); + + auto cleanParentDeviceInstanceId = internal::NormalizeDeviceInstanceIdWStringCopy(parentDeviceInstanceId); + + bool found{ false }; + + for (auto const& endpoint : m_pendingEndpointDefinitions) + { + if (internal::NormalizeDeviceInstanceIdWStringCopy(endpoint->ParentDeviceInstanceId) == cleanParentDeviceInstanceId) + { + endpointDefinitions.push_back(endpoint); + found = true; + } + } + + + if (found) + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"One or more matching endpoints found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id") + ); + + return S_OK; + } + else + { + TraceLoggingWrite( + MidiKSAggregateTransportTelemetryProvider::Provider(), + MIDI_TRACE_EVENT_VERBOSE, + TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingPointer(this, "this"), + TraceLoggingWideString(L"No matches found", MIDI_TRACE_EVENT_MESSAGE_FIELD), + TraceLoggingWideString(cleanParentDeviceInstanceId.c_str(), "parent device instance id") + ); + + return E_NOTFOUND; + } + +} + _Use_decl_annotations_ @@ -1673,14 +1735,17 @@ _Use_decl_annotations_ HRESULT CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( DeviceInformation filterDevice, - std::vector>& pinListToAddTo -) + std::vector>& pinListToAddTo, + uint8_t& countMidiSourcePinsAdded, + uint8_t& countMidiDestinationPinsAdded + ) { // Wrapper opens the handle internally. KsHandleWrapper deviceHandleWrapper(filterDevice.Id().c_str()); RETURN_IF_FAILED(deviceHandleWrapper.Open()); - + countMidiSourcePinsAdded = 0; + countMidiDestinationPinsAdded = 0; // Driver-supplied name. This is needed for WinMM backwards compatibility std::wstring driverSuppliedName{}; @@ -1779,6 +1844,8 @@ CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiOutputPinIndexForThisFilter); midiOutputPinIndexForThisFilter++; + + countMidiDestinationPinsAdded++; } else if (dataFlow == KSPIN_DATAFLOW_OUT) { @@ -1788,6 +1855,8 @@ CMidi2KSAggregateMidiEndpointManager2::GetMidi1FilterPins( pinDefinition->PortIndexWithinThisFilterAndDirection = static_cast(midiInputPinIndexForThisFilter); midiInputPinIndexForThisFilter++; + + countMidiSourcePinsAdded++; } pinListToAddTo.push_back(pinDefinition); @@ -1877,6 +1946,39 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( return S_OK; } +bool EndpointHasRoomForMoreNewPins( + _In_ std::shared_ptr endpoint, + _In_ uint8_t countNewSourcePins, + _In_ uint8_t countNewDestinationPins) +{ + + uint8_t countFoundSourcePins{ 0 }; + uint8_t countFoundDestinationPins{ 0 }; + + // count the source and destination pins + + for (auto const& pin : endpoint->MidiPins) + { + if (pin->DataFlowFromUserPerspective == MidiFlow::MidiFlowIn) + { + countFoundSourcePins++; + } + else + { + countFoundDestinationPins++; + } + } + + if ((countFoundSourcePins + countNewSourcePins <= 16) && + (countFoundDestinationPins + countNewDestinationPins <= 16)) + { + return true; + } + + return false; + +} + _Use_decl_annotations_ HRESULT @@ -1917,7 +2019,9 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( // like group index and naming require the full context of all filters/pins on the // parent device. We want to get these before we try creating parents or endpoints std::vector> pinList{ }; - RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList)); + uint8_t countEnumeratedMidiSourcePins{ 0 }; + uint8_t countEnumeratedMidiDestinationPins{ 0 }; + RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins)); if (pinList.size() == 0) { @@ -1953,8 +2057,6 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( } - - // We have MIDI 1.0 pins to process, so we'll need to find or create a parent device // and also find or create an endpoint under that parent, which has sufficient room // for these pins. @@ -1973,15 +2075,52 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( parentDeviceDefinition )); + std::vector> foundEndpoints{}; - // =================================================================== - // Find or create the endpoint + // do we already have one or more pending endpoints for this? + if (SUCCEEDED(FindAllPendingEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints))) + { + std::shared_ptr existingPendingEndpointDefinition{ nullptr }; - std::shared_ptr endpointDefinition{ nullptr }; + // find an endpoint with room for another interface with pins. + // We're going by the pin counts returned when we enumerated + // all pins for this interface - // check to see if we already have an *activated* endpoint for this filter - if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str())) + // check the latest endpoint first + + for (size_t i = foundEndpoints.size() - 1; i >= 0; i--) + { + auto ep = foundEndpoints[i]; + + if (EndpointHasRoomForMoreNewPins(ep, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins)) + { + existingPendingEndpointDefinition = ep; + + break; + } + } + + // if this check fails, we fall through to creating a new endpoint definition + if (existingPendingEndpointDefinition != nullptr) + { + //// first MIDI 1 pin we're processing for this interface + //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); + + // add our new pins into the existing endpoint definition + existingPendingEndpointDefinition->MidiPins.insert(existingPendingEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingPendingEndpointDefinition)); + + return S_OK; + } + + } + else if (ActiveKSAEndpointForDeviceExists(parentInstanceId.c_str())) { + std::shared_ptr existingActivatedEndpointDefinition{ nullptr }; + + // check to see if we already have any activated endpoints for this device + TraceLoggingWrite( MidiKSAggregateTransportTelemetryProvider::Provider(), MIDI_TRACE_EVENT_VERBOSE, @@ -1993,35 +2132,49 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( TraceLoggingWideString(parentInstanceId.c_str(), "parent instance id") ); - std::shared_ptr existingActivatedEndpointDefinition { nullptr }; - - // Get all existing endpoint definitions for this parent device - // If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one - - std::vector> foundEndpoints {}; + // TODO: Potential problem with this code: + // Because it's just using pin counts, after some add/removes + // the group indexes could change. Is that ok? Seems not. + // need to see what impact that will have on enumerated MIDI ports + if (SUCCEEDED(FindAllActivatedEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints))) { // find an endpoint with room for another interface with pins. - // We're going by the highest group index + // We're going by the pin counts returned when we enumerated + // all pins for this interface + // check the latest endpoint first + + for (size_t i = foundEndpoints.size()-1; i >= 0; i--) + { + auto ep = foundEndpoints[i]; - // TEMP - existingActivatedEndpointDefinition = foundEndpoints[0]; - } + if (EndpointHasRoomForMoreNewPins(ep, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins)) + { + existingActivatedEndpointDefinition = ep; + break; + } + + } + } - //// first MIDI 1 pin we're processing for this interface - //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); - //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); + // if this check fails, we fall through to creating a new endpoint definition + if (existingActivatedEndpointDefinition != nullptr) + { + //// first MIDI 1 pin we're processing for this interface + //RETURN_IF_FAILED(FindActivatedEndpointDefinitionForFilterDevice(parentInstanceId.c_str(), existingActivatedEndpointDefinition)); + //RETURN_HR_IF_NULL(E_POINTER, existingActivatedEndpointDefinition); - // add our new pins into the existing endpoint definition - existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingActivatedEndpointDefinition)); + // add our new pins into the existing endpoint definition + existingActivatedEndpointDefinition->MidiPins.insert(existingActivatedEndpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), existingActivatedEndpointDefinition)); - RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); + RETURN_IF_FAILED(DeviceUpdateExistingMidiUmpEndpointWithFilterChanges(existingActivatedEndpointDefinition)); - return S_OK; + return S_OK; + } } else { @@ -2037,28 +2190,18 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( ); } - // if the endpointDefinition is null, that means we haven't found an existing - // activated endpoint definition we need to use, and so we proceed to check - // for an existing pending endpoint definition. If found, it's used. If not - // found, the function will create a new one for us to use, with all the - // endpoint-specific details (excluding pins) populated. - if (endpointDefinition == nullptr) - { - // first MIDI 1 pin we're processing for this interface - RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); - RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); - - - - // TODO: If the existing endpoint definition doesn't have room for this set of pins, we need to create a new one + // =================================================================== + // Create the endpoint + std::shared_ptr endpointDefinition{ nullptr }; + RETURN_IF_FAILED(FindOrCreatePendingEndpointDefinitionForFilterDevice(filterDevice, endpointDefinition)); + RETURN_HR_IF_NULL(E_POINTER, endpointDefinition); - // add our new pins into the existing endpoint definition - endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); - pinList.clear(); // just make sure we don't use this one, accidentally - } + // add our new pins + endpointDefinition->MidiPins.insert(endpointDefinition->MidiPins.end(), pinList.begin(), pinList.end()); + pinList.clear(); // just make sure we don't use this one, accidentally RETURN_IF_FAILED(UpdateNewPinDefinitions(filterDevice.Id().c_str(), endpointDefinition)); diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h index 79cc2cdac..06ca403a9 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.h @@ -130,8 +130,12 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT FindAllActivatedEndpointDefinitionsForParentDevice( _In_ std::wstring parentDeviceInstanceId, - _Inout_ std::vector>& endpointDefinitions - ); + _Inout_ std::vector>& endpointDefinitions); + + HRESULT FindAllPendingEndpointDefinitionsForParentDevice( + _In_ std::wstring parentDeviceInstanceId, + _Inout_ std::vector>& endpointDefinitions); + HRESULT FindPendingEndpointDefinitionForParentDevice( _In_ std::wstring parentDeviceInstanceId, @@ -160,7 +164,9 @@ class CMidi2KSAggregateMidiEndpointManager2 : HRESULT GetMidi1FilterPins( _In_ DeviceInformation, - _In_ std::vector>&); + _In_ std::vector>&, + _Inout_ uint8_t& countMidiSourcePinsAdded, + _Inout_ uint8_t& countMidiDestinationPinsAdded); HRESULT GetKSDriverSuppliedName(_In_ HANDLE hFilter, _Inout_ std::wstring& name); From faf3262986ae27d6f17fadd069b2514a1abf10d5 Mon Sep 17 00:00:00 2001 From: Pete Brown Date: Tue, 17 Feb 2026 11:14:21 -0500 Subject: [PATCH 18/18] Update Midi2.KSAggregateMidiEndpointManager2.cpp --- .../Midi2.KSAggregateMidiEndpointManager2.cpp | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp index 937de3f83..a0b8d1a8e 100644 --- a/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp +++ b/src/api/Transport/KSAggregateTransport/Midi2.KSAggregateMidiEndpointManager2.cpp @@ -1946,11 +1946,19 @@ CMidi2KSAggregateMidiEndpointManager2::UpdateNewPinDefinitions( return S_OK; } + +// TODO: Potential problem with this code: +// Because it's just using pin counts, after some add/removes +// the group indexes could change. Is that ok? Seems not. +// need to see what impact that will have on enumerated MIDI ports + bool EndpointHasRoomForMoreNewPins( _In_ std::shared_ptr endpoint, _In_ uint8_t countNewSourcePins, _In_ uint8_t countNewDestinationPins) { + if (endpoint == nullptr) return false; + uint8_t countFoundSourcePins{ 0 }; uint8_t countFoundDestinationPins{ 0 }; @@ -2023,6 +2031,23 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( uint8_t countEnumeratedMidiDestinationPins{ 0 }; RETURN_IF_FAILED(GetMidi1FilterPins(filterDevice, pinList, countEnumeratedMidiSourcePins, countEnumeratedMidiDestinationPins)); + + + + // TODO: Should just get that list of pins back, and even if it's > 32, just break it up into separate + // endpoints of pins (16 in, 16 out, max). May need a "distribute pins" type of function + // to return a vector of vectors of pins + // Can change the GetMidi1FilterPins to return a vector of KsAggregateEndpointMidiPinList entries, + // each of which has a list of in pins and out pins or whatever ends up convenient + // + // but we cannot have the same filter opened by two different endpoints, so we're effectively limited to + // 16 in/16 out per filter + // + + + + + if (pinList.size() == 0) { TraceLoggingWrite( @@ -2088,6 +2113,8 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( // check the latest endpoint first + + for (size_t i = foundEndpoints.size() - 1; i >= 0; i--) { auto ep = foundEndpoints[i]; @@ -2133,10 +2160,7 @@ CMidi2KSAggregateMidiEndpointManager2::OnFilterDeviceInterfaceAdded( ); - // TODO: Potential problem with this code: - // Because it's just using pin counts, after some add/removes - // the group indexes could change. Is that ok? Seems not. - // need to see what impact that will have on enumerated MIDI ports + foundEndpoints.clear(); if (SUCCEEDED(FindAllActivatedEndpointDefinitionsForParentDevice(parentInstanceId.c_str(), foundEndpoints))) {