From 7034a41de74287a48c5c1b6552f40d2d9fa7af67 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Wed, 11 Mar 2026 13:58:49 +0000 Subject: [PATCH 01/11] fix: prevent false-positive gap detection on zero-consumption periods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gap detector in previous_days_modal_filter() checks if consecutive values are equal (data[m] == data[m+5]) to find missing data. After clean_incrementing_reverse(), zero-consumption overnight periods have equal consecutive values, triggering false gap detection and injecting phantom load (~6 kWh/night for a 24 kWh/day average). Track sensor data point provenance during minute_data() processing via a new data_point_minutes set parameter. In the gap detector, check whether the sensor was actively reporting during each gap period. If the sensor was online (≥1 data point/hour), skip filling. If offline, fill as before. Supersedes #3546 which attempted to fix the symptom via interpolation. 
Co-Authored-By: Claude Opus 4.6 --- apps/predbat/fetch.py | 22 +- .../tests/test_fill_load_from_power.py | 525 +++++++++++++++++- .../predbat/tests/test_previous_days_modal.py | 116 ++++ apps/predbat/utils.py | 3 + 4 files changed, 662 insertions(+), 4 deletions(-) diff --git a/apps/predbat/fetch.py b/apps/predbat/fetch.py index 7809232d0..34802436e 100644 --- a/apps/predbat/fetch.py +++ b/apps/predbat/fetch.py @@ -436,6 +436,20 @@ def previous_days_modal_filter(self, data): num_gaps += gap_minutes gap_list.append((gap_start_minute_previous, gap_minutes)) + # Filter false-positive gaps where sensor was actively reporting + if hasattr(self, "load_data_point_minutes") and self.load_data_point_minutes: + filtered_gaps = [] + for gap_start, gap_minutes_len in gap_list: + gap_data_count = sum(1 for m in self.load_data_point_minutes if gap_start <= m < gap_start + gap_minutes_len) + # Need at least 1 data point per hour (min 2) to consider sensor "active" + min_data_points = max(gap_minutes_len // 60, 2) + if gap_data_count >= min_data_points: + self.log("Info: Skipping gap at minute {} ({} min) - sensor active ({} of {} points)".format(gap_start, gap_minutes_len, gap_data_count, min_data_points)) + else: + filtered_gaps.append((gap_start, gap_minutes_len)) + gap_list = filtered_gaps + num_gaps = sum(g[1] for g in gap_list) + # Work out total number of gap_minutes if num_gaps > 0: self.log("Warn: Found {} gaps in load_today totalling {} minutes to fill using average data".format(len(gap_list), num_gaps)) @@ -588,7 +602,7 @@ def minute_data_import_export(self, max_days_previous, now_utc, key, scale=1.0, return import_today - def minute_data_load(self, now_utc, entity_name, max_days_previous, load_scaling=1.0, required_unit=None, interpolate=False, pad=True): + def minute_data_load(self, now_utc, entity_name, max_days_previous, load_scaling=1.0, required_unit=None, interpolate=False, pad=True, data_point_minutes=None): """ Download one or more entities for load data """ @@ 
-639,6 +653,7 @@ def minute_data_load(self, now_utc, entity_name, max_days_previous, load_scaling accumulate=load_minutes, required_unit=required_unit, interpolate=interpolate, + data_point_minutes=data_point_minutes, ) else: if history is None: @@ -680,6 +695,7 @@ def fetch_sensor_data(self, save=True): self.pv_today = {} self.load_minutes = {} self.load_minutes_age = 0 + self.load_data_point_minutes = set() self.load_forecast = {} self.load_forecast_array = [] self.pv_forecast_minute = {} @@ -733,7 +749,7 @@ def fetch_sensor_data(self, save=True): else: # Load data if "load_today" in self.args: - self.load_minutes, self.load_minutes_age = self.minute_data_load(self.now_utc, "load_today", self.max_days_previous, required_unit="kWh", load_scaling=1.0, interpolate=True) + self.load_minutes, self.load_minutes_age = self.minute_data_load(self.now_utc, "load_today", self.max_days_previous, required_unit="kWh", load_scaling=1.0, interpolate=True, data_point_minutes=self.load_data_point_minutes) self.log("Found {} load_today datapoints going back {} days".format(len(self.load_minutes), self.load_minutes_age)) self.load_minutes_now = get_now_from_cumulative(self.load_minutes, self.minutes_now, backwards=True) self.load_last_period = (self.load_minutes.get(0, 0) - self.load_minutes.get(PREDICT_STEP, 0)) * 60 / PREDICT_STEP @@ -1283,7 +1299,7 @@ def download_ge_data(self, now_utc): age = now_utc - oldest_data_time self.load_minutes_age = age.days - self.load_minutes, _ = minute_data(mdata, self.max_days_previous, now_utc, "consumption", "last_updated", backwards=True, smoothing=True, scale=1.0, clean_increment=True, interpolate=True) + self.load_minutes, _ = minute_data(mdata, self.max_days_previous, now_utc, "consumption", "last_updated", backwards=True, smoothing=True, scale=1.0, clean_increment=True, interpolate=True, data_point_minutes=self.load_data_point_minutes) self.import_today, _ = minute_data(mdata, self.max_days_previous, now_utc, "import", "last_updated", 
backwards=True, smoothing=True, scale=self.import_export_scaling, clean_increment=True) self.export_today, _ = minute_data(mdata, self.max_days_previous, now_utc, "export", "last_updated", backwards=True, smoothing=True, scale=self.import_export_scaling, clean_increment=True) self.pv_today, _ = minute_data(mdata, self.max_days_previous, now_utc, "pv", "last_updated", backwards=True, smoothing=True, scale=self.import_export_scaling, clean_increment=True) diff --git a/apps/predbat/tests/test_fill_load_from_power.py b/apps/predbat/tests/test_fill_load_from_power.py index a68c8e4a8..393fdad1e 100644 --- a/apps/predbat/tests/test_fill_load_from_power.py +++ b/apps/predbat/tests/test_fill_load_from_power.py @@ -4,12 +4,13 @@ import sys import os +from datetime import datetime, timezone, timedelta # Add the apps/predbat directory to the path sys.path.append(os.path.join(os.path.dirname(__file__), "..", "apps", "predbat")) from fetch import Fetch -from utils import dp4 +from utils import dp4, minute_data class TestFetch(Fetch): @@ -305,6 +306,523 @@ def test_fill_load_from_power_backwards_time(): print("Test 6 PASSED") +def generate_ge_cloud_history(days=8, now_utc=None): + """Generate realistic sparse 5-minute GE Cloud consumption data. + + Returns a list of dicts sorted oldest-first with cumulative 'consumption' + values and 'last_updated' ISO timestamps, mimicking GE Cloud history. 
+ """ + if now_utc is None: + now_utc = datetime.now(timezone.utc).replace(second=0, microsecond=0) + + history = [] + cumulative = 0.0 + + # Generate from oldest to newest (ascending time) + start_time = now_utc - timedelta(days=days) + current_time = start_time + + while current_time <= now_utc: + hour = current_time.hour + + # Consumption rate varies by time of day (kWh per 5 minutes) + if 23 <= hour or hour < 6: + # Overnight: very low consumption (standby ~0.1 kW) + consumption_rate = 0.1 / 12 # kWh per 5-min interval + elif 6 <= hour < 9: + # Morning peak + consumption_rate = 1.5 / 12 + elif 9 <= hour < 17: + # Daytime moderate + consumption_rate = 0.8 / 12 + elif 17 <= hour < 23: + # Evening peak + consumption_rate = 2.0 / 12 + + cumulative += consumption_rate + + history.append( + { + "consumption": round(cumulative, 4), + "last_updated": current_time.strftime("%Y-%m-%dT%H:%M:%S+00:00"), + } + ) + + current_time += timedelta(minutes=5) + + return history + + +def interpolate_sparse_data_local(data): + """Local implementation of linear interpolation between sparse points. + + Given a dict of {minute: value}, finds gaps (where consecutive minutes + are missing) and linearly interpolates between the bounding values. + Returns a new dict with all gaps filled. 
+ """ + if not data: + return {} + + result = dict(data) + min_minute = min(data.keys()) + max_minute = max(data.keys()) + + # Find filled points and sort them + filled_minutes = sorted(data.keys()) + if len(filled_minutes) < 2: + return result + + # Interpolate between each pair of filled points + for i in range(len(filled_minutes) - 1): + start_m = filled_minutes[i] + end_m = filled_minutes[i + 1] + gap = end_m - start_m + + if gap <= 1: + continue # No gap to fill + + start_val = data[start_m] + end_val = data[end_m] + + for m in range(start_m + 1, end_m): + frac = (m - start_m) / gap + result[m] = dp4(start_val + (end_val - start_val) * frac) + + return result + + +def test_minute_data_densifies_sparse_ge_cloud_data(): + """ + Prove that minute_data() with smoothing=True and clean_increment=True + produces fully dense per-minute output from sparse 5-minute GE Cloud data, + making any subsequent interpolation a no-op. + """ + print("\n=== Test 7: minute_data densifies sparse GE Cloud data ===") + + days = 8 + now_utc = datetime(2026, 3, 10, 12, 0, 0, tzinfo=timezone.utc) + history = generate_ge_cloud_history(days=days, now_utc=now_utc) + + print(f" Generated {len(history)} sparse history points over {days} days") + print(f" First: {history[0]['last_updated']} = {history[0]['consumption']}") + print(f" Last: {history[-1]['last_updated']} = {history[-1]['consumption']}") + + # Call minute_data with the EXACT params from download_ge_data (fetch.py:1286) + result, _ = minute_data( + history, + days, + now_utc, + "consumption", + "last_updated", + backwards=True, + smoothing=True, + scale=1.0, + clean_increment=True, + interpolate=True, + ) + + # Check density: every minute from 0 to 8*24*60 - 1 should have a value + total_minutes = days * 24 * 60 + filled = sum(1 for m in range(total_minutes) if m in result) + missing = total_minutes - filled + + print(f" Total minutes expected: {total_minutes}") + print(f" Minutes filled: {filled}") + print(f" Minutes missing: 
{missing}") + + assert missing == 0, f"minute_data left {missing} gaps out of {total_minutes} minutes" + + # Now run our local interpolate_sparse_data on the result and count changes + interpolated = interpolate_sparse_data_local(result) + + changes = 0 + for m in range(total_minutes): + if m in result and m in interpolated: + if abs(result[m] - interpolated[m]) > 0.0001: + changes += 1 + + print(f" Values changed by interpolation: {changes}") + print(f" Conclusion: minute_data() already produces dense output, " f"interpolate_sparse_data() is a no-op") + + assert changes == 0, f"interpolate_sparse_data changed {changes} values -- " f"minute_data output was not fully dense" + + print("Test 7 PASSED") + + +def test_minute_data_output_after_clean_incrementing_reverse(): + """ + Test that clean_incrementing_reverse converts cumulative GE Cloud data + to proper incremental values: zero-consumption overnight periods should + produce 0 values, not flat cumulative values. + """ + print("\n=== Test 8: clean_incrementing_reverse produces incremental output ===") + + days = 8 + now_utc = datetime(2026, 3, 10, 12, 0, 0, tzinfo=timezone.utc) + history = generate_ge_cloud_history(days=days, now_utc=now_utc) + + result, _ = minute_data( + history, + days, + now_utc, + "consumption", + "last_updated", + backwards=True, + smoothing=True, + scale=1.0, + clean_increment=True, + interpolate=True, + ) + + total_minutes = days * 24 * 60 + + # After clean_increment=True, output should be incremental (cumulative energy consumed). + # Minute 0 has the highest value (total consumed), and values decrease going back in time. + # The increments (result[m] - result[m+1]) should be >= 0 for all m. + + # Check that the output is monotonically non-increasing (minute 0 >= minute 1 >= ...) 
+ violations = 0 + for m in range(total_minutes - 1): + if result[m] < result[m + 1] - 0.001: # small tolerance + violations += 1 + + print(f" Monotonicity violations: {violations}") + assert violations == 0, f"Found {violations} monotonicity violations in incremental output" + + # Check overnight periods: between 23:00 and 06:00 the consumption rate is very low. + # Pick deep overnight (02:00) to avoid boundary effects at 23:00 transition. + # now_utc is 2026-03-10 12:00, so 2026-03-07 02:00 is 3 days 10 hours ago. + sample_night_start = now_utc - timedelta(days=3, hours=10) # 02:00 three days ago + night_offset_min = int((now_utc - sample_night_start).total_seconds() / 60) + + overnight_increments = [] + for m in range(night_offset_min, min(night_offset_min + 60, total_minutes - 1)): + if m in result and (m + 1) in result: + inc = result[m] - result[m + 1] + overnight_increments.append(inc) + + if overnight_increments: + avg_overnight = sum(overnight_increments) / len(overnight_increments) + max_overnight = max(overnight_increments) + print(f" Sample overnight period ({len(overnight_increments)} minutes):") + print(f" Average increment: {dp4(avg_overnight)} kWh/min") + print(f" Max increment: {dp4(max_overnight)} kWh/min") + # Overnight standby is ~0.1 kW = 0.1/60 kWh/min ~ 0.0017 + assert max_overnight < 0.01, f"Overnight increment too high: {dp4(max_overnight)} kWh/min, " f"expected near-zero for standby" + + # Check a daytime evening peak period for comparison + sample_evening = now_utc - timedelta(days=1, hours=18) # ~18:00 yesterday -> evening peak + evening_offset_min = int((now_utc - sample_evening).total_seconds() / 60) + + evening_increments = [] + for m in range(evening_offset_min, min(evening_offset_min + 60, total_minutes - 1)): + if m in result and (m + 1) in result: + inc = result[m] - result[m + 1] + evening_increments.append(inc) + + if evening_increments: + avg_evening = sum(evening_increments) / len(evening_increments) + print(f" Sample evening 
peak ({len(evening_increments)} minutes):") + print(f" Average increment: {dp4(avg_evening)} kWh/min") + # Evening is ~2.0 kW = 2.0/60 kWh/min ~ 0.033 + assert avg_evening > 0.01, f"Evening increment too low: {dp4(avg_evening)}, expected higher consumption" + + print("Test 8 PASSED") + + +def test_fill_load_no_difference_with_prior_interpolation(): + """ + Prove that fill_load_from_power produces identical results whether or + not you run interpolate_sparse_data on the minute_data output first. + Since minute_data already densifies, the interpolation is redundant. + """ + print("\n=== Test 9: fill_load_from_power unaffected by prior interpolation ===") + + days = 8 + now_utc = datetime(2026, 3, 10, 12, 0, 0, tzinfo=timezone.utc) + history = generate_ge_cloud_history(days=days, now_utc=now_utc) + + # Get dense minute_data output + load_minutes, _ = minute_data( + history, + days, + now_utc, + "consumption", + "last_updated", + backwards=True, + smoothing=True, + scale=1.0, + clean_increment=True, + interpolate=True, + ) + + # Also create an interpolated version + load_minutes_interpolated = interpolate_sparse_data_local(load_minutes) + + # Create some realistic power data for fill_load_from_power + # Just use the first 240 minutes (4 hours) to keep it manageable + test_range = 240 + load_subset = {m: load_minutes[m] for m in range(test_range + 1) if m in load_minutes} + load_subset_interp = {m: load_minutes_interpolated[m] for m in range(test_range + 1) if m in load_minutes_interpolated} + + # Generate varying power data + load_power_data = {} + for minute in range(test_range): + # Realistic varying power between 0.5-3.0 kW + load_power_data[minute] = 1500.0 + 1000.0 * ((minute % 7) / 6.0 - 0.5) + + fetch1 = TestFetch() + fetch2 = TestFetch() + + result_direct = fetch1.fill_load_from_power(load_subset, load_power_data) + result_interpolated = fetch2.fill_load_from_power(load_subset_interp, load_power_data) + + # Compare results + differences = 0 + max_diff = 0.0 + 
for m in range(test_range): + if m in result_direct and m in result_interpolated: + diff = abs(result_direct[m] - result_interpolated[m]) + if diff > 0.0001: + differences += 1 + max_diff = max(max_diff, diff) + + print(f" Test range: {test_range} minutes") + print(f" Differences found: {differences}") + print(f" Max difference: {dp4(max_diff)} kWh") + + assert differences == 0, f"fill_load_from_power produced {differences} different values " f"when input was pre-interpolated (max diff: {dp4(max_diff)})" + + print(f" Conclusion: Pre-interpolating minute_data output has zero effect on " f"fill_load_from_power results") + print("Test 9 PASSED") + + +def generate_ge_cloud_history_zero_overnight(days=8, now_utc=None): + """Generate GE Cloud data with true zero overnight consumption. + + Many real-world installations have periods of zero consumption (e.g. + battery-powered homes, solar-only, or efficient heat pump setups where + overnight consumption truly reaches zero). This generates data that + triggers the gap detector false positive. 
+ """ + if now_utc is None: + now_utc = datetime.now(timezone.utc).replace(second=0, microsecond=0) + + history = [] + cumulative = 0.0 + start_time = now_utc - timedelta(days=days) + current_time = start_time + + while current_time <= now_utc: + hour = current_time.hour + + if 23 <= hour or hour < 6: + consumption_rate = 0.0 # True zero overnight consumption + elif 6 <= hour < 9: + consumption_rate = 1.5 / 12 + elif 9 <= hour < 17: + consumption_rate = 0.8 / 12 + elif 17 <= hour < 23: + consumption_rate = 2.0 / 12 + + cumulative += consumption_rate + + history.append( + { + "consumption": round(cumulative, 4), + "last_updated": current_time.strftime("%Y-%m-%dT%H:%M:%S+00:00"), + } + ) + + current_time += timedelta(minutes=5) + + return history + + +def test_gap_detector_false_positives_on_overnight_zeros(): + """ + Demonstrate the root cause of load inflation: the gap detector in + previous_days_modal_filter (fetch.py:424-425) checks whether consecutive + PREDICT_STEP-spaced values are equal. After clean_incrementing_reverse, + zero-consumption periods have flat incremental values (the cumulative + total doesn't change when there's no consumption). This triggers + false-positive gap detection, and the gaps get filled with average + daily consumption, inflating load. + + Uses true zero overnight consumption to reproduce the real-world bug + seen in homes with battery/solar setups where overnight grid import + drops to zero. 
+ """ + print("\n=== Test 10: Gap detector false positives on overnight zeros ===") + + PREDICT_STEP = 5 # From const.py + + days = 8 + now_utc = datetime(2026, 3, 10, 12, 0, 0, tzinfo=timezone.utc) + history = generate_ge_cloud_history_zero_overnight(days=days, now_utc=now_utc) + + # Get the same minute_data output that download_ge_data produces + data, _ = minute_data( + history, + days, + now_utc, + "consumption", + "last_updated", + backwards=True, + smoothing=True, + scale=1.0, + clean_increment=True, + interpolate=True, + ) + + total_minutes = days * 24 * 60 + + # Reproduce the exact gap detection logic from fetch.py:423-437 + gap_size = 30 # Default plan_interval_minutes used as load_filter_threshold + gap_minutes = 0 + gap_start_minute_previous = None + gap_list = [] + num_gaps = 0 + max_minute = total_minutes + + for minute_previous in range(0, max_minute, PREDICT_STEP): + if data.get(minute_previous, 0) == data.get(minute_previous + PREDICT_STEP, 0): + gap_minutes += PREDICT_STEP + if gap_start_minute_previous is None: + gap_start_minute_previous = minute_previous + else: + if gap_minutes >= gap_size: + num_gaps += gap_minutes + gap_list.append((gap_start_minute_previous, gap_minutes)) + gap_minutes = 0 + gap_start_minute_previous = None + if gap_minutes >= gap_size: + num_gaps += gap_minutes + gap_list.append((gap_start_minute_previous, gap_minutes)) + + print(f" Total data span: {total_minutes} minutes ({days} days)") + print(f" Gap detection threshold: {gap_size} minutes") + print(f" Number of false gaps detected: {len(gap_list)}") + print(f" Total false gap minutes: {num_gaps}") + + # Show details of each detected gap with the time of day + overnight_gaps = 0 + overnight_gap_minutes = 0 + for gap_start, gap_len in gap_list: + # Convert minute offset back to clock time + gap_time = now_utc - timedelta(minutes=gap_start) + gap_end_time = now_utc - timedelta(minutes=gap_start + gap_len) + print(f" Gap at minute {gap_start} 
({gap_end_time.strftime('%Y-%m-%d %H:%M')}" f" to {gap_time.strftime('%H:%M')}): {gap_len} minutes") + + # Check if this gap falls in overnight hours + gap_mid_time = now_utc - timedelta(minutes=gap_start + gap_len // 2) + mid_hour = gap_mid_time.hour + if 23 <= mid_hour or mid_hour < 6: + overnight_gaps += 1 + overnight_gap_minutes += gap_len + + print(f"\n Overnight false-positive gaps: {overnight_gaps}") + print(f" Overnight false-positive minutes: {overnight_gap_minutes}") + + # The key assertion: overnight zero-consumption periods ARE being + # falsely detected as gaps. This is the root cause of load inflation. + assert len(gap_list) > 0, "Expected false gaps to be detected in zero-consumption overnight periods" + assert overnight_gaps > 0, "Expected at least some gaps to fall in overnight hours" + + # Quantify the inflation: these gaps cover ~29% of the data + gap_pct = num_gaps / total_minutes * 100 + print(f"\n ROOT CAUSE CONFIRMED: The gap detector flags {len(gap_list)} regions " f"({num_gaps} total minutes, {gap_pct:.1f}% of data) as 'gaps'.") + print(f" These would be filled with average daily consumption, inflating load.") + print(f" The issue is at fetch.py:425 -- checking data[m] == data[m+5] on") + print(f" incremental data where zero-consumption periods legitimately have") + print(f" equal consecutive values.") + + print("Test 10 PASSED") + + +def test_gap_filter_fixes_false_positives(): + """ + End-to-end test: with load_data_point_minutes populated from minute_data(), + the gap filter in previous_days_modal_filter() should skip false-positive + gaps where the sensor was actively reporting zero consumption. + + This is the complement of Test 10 — Test 10 proves the false positives exist, + this test proves the fix eliminates them. 
+ """ + print("\n=== Test 11: Gap filter fixes false positives with data_point_minutes ===") + + PREDICT_STEP_LOCAL = 5 + days = 8 + now_utc = datetime(2026, 3, 10, 12, 0, 0, tzinfo=timezone.utc) + history = generate_ge_cloud_history_zero_overnight(days=days, now_utc=now_utc) + + # Collect data_point_minutes during minute_data processing + data_point_minutes = set() + data, _ = minute_data( + history, + days, + now_utc, + "consumption", + "last_updated", + backwards=True, + smoothing=True, + scale=1.0, + clean_increment=True, + interpolate=True, + data_point_minutes=data_point_minutes, + ) + + print(f" Data points tracked: {len(data_point_minutes)}") + total_minutes = days * 24 * 60 + + # Run gap detection (same logic as fetch.py) + gap_size = 30 + gap_minutes_count = 0 + gap_start_minute_previous = None + gap_list = [] + num_gaps = 0 + max_minute = total_minutes + + for minute_previous in range(0, max_minute, PREDICT_STEP_LOCAL): + if data.get(minute_previous, 0) == data.get(minute_previous + PREDICT_STEP_LOCAL, 0): + gap_minutes_count += PREDICT_STEP_LOCAL + if gap_start_minute_previous is None: + gap_start_minute_previous = minute_previous + else: + if gap_minutes_count >= gap_size: + num_gaps += gap_minutes_count + gap_list.append((gap_start_minute_previous, gap_minutes_count)) + gap_minutes_count = 0 + gap_start_minute_previous = None + if gap_minutes_count >= gap_size: + num_gaps += gap_minutes_count + gap_list.append((gap_start_minute_previous, gap_minutes_count)) + + print(f" Gaps detected before filtering: {len(gap_list)} ({num_gaps} minutes)") + + # Now apply the same filter logic as in previous_days_modal_filter + filtered_gaps = [] + for gap_start, gap_minutes_len in gap_list: + gap_data_count = sum(1 for m in data_point_minutes if gap_start <= m < gap_start + gap_minutes_len) + min_data_points = max(gap_minutes_len // 60, 2) + if gap_data_count >= min_data_points: + print(f" Skipping gap at minute {gap_start} ({gap_minutes_len} min) - " f"sensor active 
({gap_data_count} of {min_data_points} needed)") + else: + filtered_gaps.append((gap_start, gap_minutes_len)) + + remaining_gap_minutes = sum(g[1] for g in filtered_gaps) + print(f" Gaps remaining after filtering: {len(filtered_gaps)} ({remaining_gap_minutes} minutes)") + + # The key assertion: ALL false-positive gaps should be filtered out + # because the sensor was reporting data throughout + assert len(filtered_gaps) == 0, f"Expected all false-positive gaps to be filtered out, " f"but {len(filtered_gaps)} gaps remain ({remaining_gap_minutes} minutes)" + + print(f"\n FIX CONFIRMED: All {len(gap_list)} false-positive gaps " f"({num_gaps} minutes) correctly filtered out.") + print(f" No phantom load will be injected into overnight zero-consumption periods.") + + print("Test 11 PASSED") + + def run_all_tests(my_predbat=None): """Run all tests""" print("\n" + "=" * 60) @@ -318,6 +836,11 @@ def run_all_tests(my_predbat=None): test_fill_load_from_power_single_minute_period() test_fill_load_from_power_zero_load() test_fill_load_from_power_backwards_time() + test_minute_data_densifies_sparse_ge_cloud_data() + test_minute_data_output_after_clean_incrementing_reverse() + test_fill_load_no_difference_with_prior_interpolation() + test_gap_detector_false_positives_on_overnight_zeros() + test_gap_filter_fixes_false_positives() print("\n" + "=" * 60) print("✅ ALL TESTS PASSED") diff --git a/apps/predbat/tests/test_previous_days_modal.py b/apps/predbat/tests/test_previous_days_modal.py index 5cb9d597b..3cc375f1f 100644 --- a/apps/predbat/tests/test_previous_days_modal.py +++ b/apps/predbat/tests/test_previous_days_modal.py @@ -330,6 +330,122 @@ def mock_get_arg(key, default=None): else: print("Partial day (<16h gaps) correctly filled from {} kWh to {} kWh (proportional)".format(dp2(initial_day2_total), dp2(day2_filled_total))) + # Test 6: Overnight zeros with load_data_point_minutes → NOT filled + print("Test 6: Overnight zeros with sensor data points → gaps skipped") + + 
my_predbat.days_previous = [1] + my_predbat.days_previous_weight = [1.0] + my_predbat.load_minutes_age = 1 + my_predbat.load_filter_modal = False + + test_data = {} + # Day 1: 12 hours active (1 kWh/hour), 12 hours zero consumption + step_increment = 1.0 / 60 + running_total = 0 + + # First 12 hours: active consumption + for minute in range(0, 12 * 60): + running_total += step_increment + test_data[24 * 60 - minute - 1] = dp4(running_total) + # Last 12 hours: zero consumption (sensor online but no usage) + for minute in range(12 * 60, 24 * 60): + test_data[24 * 60 - minute - 1] = dp4(running_total) + + initial_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # Set up load_data_point_minutes covering full day (sensor was active throughout) + my_predbat.load_data_point_minutes = set(range(0, 24 * 60)) + + my_predbat.previous_days_modal_filter(test_data) + + final_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # With sensor data points covering the whole day, gaps should be SKIPPED + # Total should remain at 12 kWh (not inflated to 18 kWh) + if final_data_sum != initial_data_sum: + print("ERROR: Expected total to stay at {} kWh (sensor active), got {} kWh".format(initial_data_sum, final_data_sum)) + failed = True + else: + print("Correctly skipped gap filling: total stayed at {} kWh".format(final_data_sum)) + + # Test 7: 12h gap with NO data points in gap → IS filled + print("Test 7: Gap with no sensor data points → gap filled") + + my_predbat.days_previous = [1] + my_predbat.days_previous_weight = [1.0] + my_predbat.load_minutes_age = 1 + my_predbat.load_filter_modal = False + + test_data = {} + step_increment = 1.0 / 60 + running_total = 0 + + # First 12 hours: active consumption + for minute in range(0, 12 * 60): + running_total += step_increment + test_data[24 * 60 - minute - 1] = dp4(running_total) + # Last 12 hours: zero (sensor offline - no data) + for minute in range(12 * 60, 24 * 60): + test_data[24 * 60 - minute - 1] = 
dp4(running_total) + + initial_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # Set load_data_point_minutes to only the non-gap region (minutes 720-1439 backward) + # The gap is at backward minutes 0-719 (recent zero-consumption period, sensor offline) + # Data points exist only at minutes 720-1439 (earlier period with actual consumption) + my_predbat.load_data_point_minutes = set(range(12 * 60, 24 * 60)) + + my_predbat.previous_days_modal_filter(test_data) + + final_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # Gap should be filled since sensor was offline during the gap period + expected_final_total = 18.0 + if final_data_sum != expected_final_total: + print("ERROR: Expected gap filling to produce {} kWh, got {} kWh".format(expected_final_total, final_data_sum)) + failed = True + else: + print("Correctly filled gap (sensor offline): {} kWh → {} kWh".format(initial_data_sum, final_data_sum)) + + # Test 8: No load_data_point_minutes attribute → existing behavior unchanged + print("Test 8: No load_data_point_minutes → backward compat (gaps filled)") + + my_predbat.days_previous = [1] + my_predbat.days_previous_weight = [1.0] + my_predbat.load_minutes_age = 1 + my_predbat.load_filter_modal = False + + test_data = {} + step_increment = 1.0 / 60 + running_total = 0 + + for minute in range(0, 12 * 60): + running_total += step_increment + test_data[24 * 60 - minute - 1] = dp4(running_total) + for minute in range(12 * 60, 24 * 60): + test_data[24 * 60 - minute - 1] = dp4(running_total) + + initial_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # Remove load_data_point_minutes entirely + if hasattr(my_predbat, "load_data_point_minutes"): + delattr(my_predbat, "load_data_point_minutes") + + my_predbat.previous_days_modal_filter(test_data) + + final_data_sum = dp2(add_incrementing_sensor_total(test_data)) + + # Without the attribute, gaps should be filled as before + expected_final_total = 18.0 + if final_data_sum != 
expected_final_total: + print("ERROR: Expected backward-compat gap filling to produce {} kWh, got {} kWh".format(expected_final_total, final_data_sum)) + failed = True + else: + print("Backward compat correct: gaps filled as before, {} kWh → {} kWh".format(initial_data_sum, final_data_sum)) + + # Restore load_minutes_age for other tests + my_predbat.load_minutes_age = 7 + # Restore original get_arg method my_predbat.get_arg = original_get_arg diff --git a/apps/predbat/utils.py b/apps/predbat/utils.py index 409a4b109..508ecf747 100644 --- a/apps/predbat/utils.py +++ b/apps/predbat/utils.py @@ -315,6 +315,7 @@ def minute_data( max_increment=MAX_INCREMENT, interpolate=False, debug=False, + data_point_minutes=None, ): """ Turns data from HA into a hash of data indexed by minute with the data being the value @@ -467,6 +468,8 @@ def minute_data( timed_to = to_time - now minutes = int(timed.total_seconds() / 60) + if data_point_minutes is not None and minute_min <= minutes <= minute_max: + data_point_minutes.add(minutes) if to_time: minutes_to = int(timed_to.total_seconds() / 60) minutes_delta = (timed_to.total_seconds() - timed.total_seconds()) / 60.0 From 49d26374b768e85c26dbbb05bd22b82e9aec3fb1 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 06:16:05 +0000 Subject: [PATCH 02/11] feat(gateway): add Python protobuf bindings for gateway telemetry and plan Co-Authored-By: Claude Opus 4.6 --- .cspell.json | 2 + apps/predbat/proto/__init__.py | 0 apps/predbat/proto/gateway_status.proto | 150 +++++++++++++++++++++++ apps/predbat/proto/gateway_status_pb2.py | 58 +++++++++ 4 files changed, 210 insertions(+) create mode 100644 apps/predbat/proto/__init__.py create mode 100644 apps/predbat/proto/gateway_status.proto create mode 100644 apps/predbat/proto/gateway_status_pb2.py diff --git a/.cspell.json b/.cspell.json index 314e8837e..1d34f3f5a 100644 --- a/.cspell.json +++ b/.cspell.json @@ -14,6 +14,8 @@ "ignorePaths": [ "**/*.json", "**/*.yaml", + 
"**/*_pb2.py", + "**/*.proto", ".gitignore", ], "import": [ diff --git a/apps/predbat/proto/__init__.py b/apps/predbat/proto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/predbat/proto/gateway_status.proto b/apps/predbat/proto/gateway_status.proto new file mode 100644 index 000000000..67d177d43 --- /dev/null +++ b/apps/predbat/proto/gateway_status.proto @@ -0,0 +1,150 @@ +syntax = "proto3"; +package predbat; + +// Gateway Status Protocol Buffer Schema +// +// Schema version: 1 +// This is the canonical contract for ESP32 gateway → cloud MQTT status payloads. +// Proto3 zero-value omission means unpopulated future fields cost nothing on the wire. + +// Mirrors InverterType enum in types.h (values must match 1:1 for safe casting) +enum InverterType { + INVERTER_TYPE_UNKNOWN = 0; + INVERTER_TYPE_SOLIS_HYBRID = 1; + INVERTER_TYPE_SOLIS_AC = 2; + INVERTER_TYPE_SOFAR_G3 = 3; + INVERTER_TYPE_GROWATT_SPH = 4; + INVERTER_TYPE_DEYE_SUNSYNK = 5; + INVERTER_TYPE_GIVENERGY = 6; + INVERTER_TYPE_GIVENERGY_EMS = 7; + INVERTER_TYPE_CUSTOM = 8; +} + +message BatteryStatus { + uint32 soc_percent = 1; // 0-100 + float voltage_v = 2; + float current_a = 3; // positive = charging + int32 power_w = 4; // positive = charging + float temperature_c = 5; + // Future: SunSpec 802-inspired fields (populated when available) + uint32 soh_percent = 6; // state of health, 0 = not available + uint32 cycle_count = 7; + int32 capacity_wh = 8; // rated capacity +} + +message PvStatus { + int32 power_w = 1; +} + +message GridStatus { + int32 power_w = 1; // positive = importing from grid + // Future: grid quality fields + float voltage_v = 2; // 0 = not available + float frequency_hz = 3; +} + +message LoadStatus { + int32 power_w = 1; // house consumption +} + +message InverterData { + int32 active_power_w = 1; + float temperature_c = 2; +} + +message ControlStatus { + uint32 mode = 1; // OperatingMode enum (0=auto,1=charge,2=discharge,3=idle) + bool charge_enabled = 2; 
+ bool discharge_enabled = 3; + uint32 charge_rate_w = 4; + uint32 discharge_rate_w = 5; + uint32 reserve_soc = 6; // min SOC % + uint32 target_soc = 7; // charge target SOC % + uint32 force_power_w = 8; + uint32 command_expires = 9; // unix timestamp +} + +message ScheduleStatus { + uint32 charge_start = 1; // HHMM format + uint32 charge_end = 2; + uint32 discharge_start = 3; + uint32 discharge_end = 4; +} + +message EmsSubInverter { + uint32 soc = 1; + int32 battery_w = 2; + int32 pv_w = 3; + int32 grid_w = 4; + float temp_c = 5; +} + +message EmsStatus { + uint32 num_inverters = 1; + uint32 total_soc = 2; + int32 total_charge_w = 3; + int32 total_discharge_w = 4; + int32 total_grid_w = 5; + int32 total_pv_w = 6; + int32 total_load_w = 7; + repeated EmsSubInverter sub_inverters = 8; +} + +// Future: energy counters (cumulative Wh, populated when available) +message EnergyCounters { + uint32 pv_total_wh = 1; + uint32 grid_import_total_wh = 2; + uint32 grid_export_total_wh = 3; + uint32 battery_charge_total_wh = 4; + uint32 battery_discharge_total_wh = 5; + uint32 consumption_total_wh = 6; +} + +message InverterEntry { + InverterType type = 1; + string serial = 2; + string ip = 3; + bool connected = 4; + bool active = 5; + BatteryStatus battery = 6; + PvStatus pv = 7; + GridStatus grid = 8; + LoadStatus load = 9; + InverterData inverter = 10; + ControlStatus control = 11; + ScheduleStatus schedule = 12; + EmsStatus ems = 13; + EnergyCounters energy = 14; // future: not populated yet +} + +message GatewayStatus { + string device_id = 1; + uint32 dongle_count = 2; + string firmware = 3; + uint32 timestamp = 4; // unix timestamp + uint32 schema_version = 5; // currently 1 + repeated InverterEntry inverters = 6; +} + +// Execution plan sent from PredBat cloud to gateway +// Published to predbat/devices/{id}/schedule as protobuf + +message PlanEntry { + bool enabled = 1; + uint32 start_hour = 2; // 0-23 (local time, per timezone field) + uint32 start_minute = 3; // 0-59 
+ uint32 end_hour = 4; // 0-23 + uint32 end_minute = 5; // 0-59 + uint32 mode = 6; // OperatingMode: 0=auto, 1=charge, 2=discharge, 3=idle + uint32 power_w = 7; // target power + uint32 target_soc = 8; // target SOC for charge, min SOC for discharge + uint32 days_of_week = 9; // bitmask: bit 0 = Sunday, bit 6 = Saturday + bool use_native = 10; // true = write to inverter schedule registers +} + +message ExecutionPlan { + uint32 timestamp = 1; // when plan was generated (unix epoch) + uint32 plan_version = 2; // monotonic, gateway skips stale plans + string timezone = 3; // IANA timezone (e.g. "Europe/London") + repeated PlanEntry entries = 4; +} diff --git a/apps/predbat/proto/gateway_status_pb2.py b/apps/predbat/proto/gateway_status_pb2.py new file mode 100644 index 000000000..0d2a34118 --- /dev/null +++ b/apps/predbat/proto/gateway_status_pb2.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: proto/gateway_status.proto +# Protobuf Python Version: 6.33.4 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 6, 33, 4, "", "proto/gateway_status.proto") +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1aproto/gateway_status.proto\x12\x07predbat"\xb1\x01\n\rBatteryStatus\x12\x13\n\x0bsoc_percent\x18\x01 \x01(\r\x12\x11\n\tvoltage_v\x18\x02 \x01(\x02\x12\x11\n\tcurrent_a\x18\x03 \x01(\x02\x12\x0f\n\x07power_w\x18\x04 \x01(\x05\x12\x15\n\rtemperature_c\x18\x05 \x01(\x02\x12\x13\n\x0bsoh_percent\x18\x06 
\x01(\r\x12\x13\n\x0b\x63ycle_count\x18\x07 \x01(\r\x12\x13\n\x0b\x63\x61pacity_wh\x18\x08 \x01(\x05"\x1b\n\x08PvStatus\x12\x0f\n\x07power_w\x18\x01 \x01(\x05"F\n\nGridStatus\x12\x0f\n\x07power_w\x18\x01 \x01(\x05\x12\x11\n\tvoltage_v\x18\x02 \x01(\x02\x12\x14\n\x0c\x66requency_hz\x18\x03 \x01(\x02"\x1d\n\nLoadStatus\x12\x0f\n\x07power_w\x18\x01 \x01(\x05"=\n\x0cInverterData\x12\x16\n\x0e\x61\x63tive_power_w\x18\x01 \x01(\x05\x12\x15\n\rtemperature_c\x18\x02 \x01(\x02"\xda\x01\n\rControlStatus\x12\x0c\n\x04mode\x18\x01 \x01(\r\x12\x16\n\x0e\x63harge_enabled\x18\x02 \x01(\x08\x12\x19\n\x11\x64ischarge_enabled\x18\x03 \x01(\x08\x12\x15\n\rcharge_rate_w\x18\x04 \x01(\r\x12\x18\n\x10\x64ischarge_rate_w\x18\x05 \x01(\r\x12\x13\n\x0breserve_soc\x18\x06 \x01(\r\x12\x12\n\ntarget_soc\x18\x07 \x01(\r\x12\x15\n\rforce_power_w\x18\x08 \x01(\r\x12\x17\n\x0f\x63ommand_expires\x18\t \x01(\r"j\n\x0eScheduleStatus\x12\x14\n\x0c\x63harge_start\x18\x01 \x01(\r\x12\x12\n\ncharge_end\x18\x02 \x01(\r\x12\x17\n\x0f\x64ischarge_start\x18\x03 \x01(\r\x12\x15\n\rdischarge_end\x18\x04 \x01(\r"^\n\x0e\x45msSubInverter\x12\x0b\n\x03soc\x18\x01 \x01(\r\x12\x11\n\tbattery_w\x18\x02 \x01(\x05\x12\x0c\n\x04pv_w\x18\x03 \x01(\x05\x12\x0e\n\x06grid_w\x18\x04 \x01(\x05\x12\x0e\n\x06temp_c\x18\x05 \x01(\x02"\xd8\x01\n\tEmsStatus\x12\x15\n\rnum_inverters\x18\x01 \x01(\r\x12\x11\n\ttotal_soc\x18\x02 \x01(\r\x12\x16\n\x0etotal_charge_w\x18\x03 \x01(\x05\x12\x19\n\x11total_discharge_w\x18\x04 \x01(\x05\x12\x14\n\x0ctotal_grid_w\x18\x05 \x01(\x05\x12\x12\n\ntotal_pv_w\x18\x06 \x01(\x05\x12\x14\n\x0ctotal_load_w\x18\x07 \x01(\x05\x12.\n\rsub_inverters\x18\x08 \x03(\x0b\x32\x17.predbat.EmsSubInverter"\xc4\x01\n\x0e\x45nergyCounters\x12\x13\n\x0bpv_total_wh\x18\x01 \x01(\r\x12\x1c\n\x14grid_import_total_wh\x18\x02 \x01(\r\x12\x1c\n\x14grid_export_total_wh\x18\x03 \x01(\r\x12\x1f\n\x17\x62\x61ttery_charge_total_wh\x18\x04 \x01(\r\x12"\n\x1a\x62\x61ttery_discharge_total_wh\x18\x05 
\x01(\r\x12\x1c\n\x14\x63onsumption_total_wh\x18\x06 \x01(\r"\xc8\x03\n\rInverterEntry\x12#\n\x04type\x18\x01 \x01(\x0e\x32\x15.predbat.InverterType\x12\x0e\n\x06serial\x18\x02 \x01(\t\x12\n\n\x02ip\x18\x03 \x01(\t\x12\x11\n\tconnected\x18\x04 \x01(\x08\x12\x0e\n\x06\x61\x63tive\x18\x05 \x01(\x08\x12\'\n\x07\x62\x61ttery\x18\x06 \x01(\x0b\x32\x16.predbat.BatteryStatus\x12\x1d\n\x02pv\x18\x07 \x01(\x0b\x32\x11.predbat.PvStatus\x12!\n\x04grid\x18\x08 \x01(\x0b\x32\x13.predbat.GridStatus\x12!\n\x04load\x18\t \x01(\x0b\x32\x13.predbat.LoadStatus\x12\'\n\x08inverter\x18\n \x01(\x0b\x32\x15.predbat.InverterData\x12\'\n\x07\x63ontrol\x18\x0b \x01(\x0b\x32\x16.predbat.ControlStatus\x12)\n\x08schedule\x18\x0c \x01(\x0b\x32\x17.predbat.ScheduleStatus\x12\x1f\n\x03\x65ms\x18\r \x01(\x0b\x32\x12.predbat.EmsStatus\x12\'\n\x06\x65nergy\x18\x0e \x01(\x0b\x32\x17.predbat.EnergyCounters"\xa0\x01\n\rGatewayStatus\x12\x11\n\tdevice_id\x18\x01 \x01(\t\x12\x14\n\x0c\x64ongle_count\x18\x02 \x01(\r\x12\x10\n\x08\x66irmware\x18\x03 \x01(\t\x12\x11\n\ttimestamp\x18\x04 \x01(\r\x12\x16\n\x0eschema_version\x18\x05 \x01(\r\x12)\n\tinverters\x18\x06 \x03(\x0b\x32\x16.predbat.InverterEntry"\xc9\x01\n\tPlanEntry\x12\x0f\n\x07\x65nabled\x18\x01 \x01(\x08\x12\x12\n\nstart_hour\x18\x02 \x01(\r\x12\x14\n\x0cstart_minute\x18\x03 \x01(\r\x12\x10\n\x08\x65nd_hour\x18\x04 \x01(\r\x12\x12\n\nend_minute\x18\x05 \x01(\r\x12\x0c\n\x04mode\x18\x06 \x01(\r\x12\x0f\n\x07power_w\x18\x07 \x01(\r\x12\x12\n\ntarget_soc\x18\x08 \x01(\r\x12\x14\n\x0c\x64\x61ys_of_week\x18\t \x01(\r\x12\x12\n\nuse_native\x18\n \x01(\x08"o\n\rExecutionPlan\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x14\n\x0cplan_version\x18\x02 \x01(\r\x12\x10\n\x08timezone\x18\x03 \x01(\t\x12#\n\x07\x65ntries\x18\x04 
\x03(\x0b\x32\x12.predbat.PlanEntry*\x98\x02\n\x0cInverterType\x12\x19\n\x15INVERTER_TYPE_UNKNOWN\x10\x00\x12\x1e\n\x1aINVERTER_TYPE_SOLIS_HYBRID\x10\x01\x12\x1a\n\x16INVERTER_TYPE_SOLIS_AC\x10\x02\x12\x1a\n\x16INVERTER_TYPE_SOFAR_G3\x10\x03\x12\x1d\n\x19INVERTER_TYPE_GROWATT_SPH\x10\x04\x12\x1e\n\x1aINVERTER_TYPE_DEYE_SUNSYNK\x10\x05\x12\x1b\n\x17INVERTER_TYPE_GIVENERGY\x10\x06\x12\x1f\n\x1bINVERTER_TYPE_GIVENERGY_EMS\x10\x07\x12\x18\n\x14INVERTER_TYPE_CUSTOM\x10\x08\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "proto.gateway_status_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_INVERTERTYPE"]._serialized_start = 2197 + _globals["_INVERTERTYPE"]._serialized_end = 2477 + _globals["_BATTERYSTATUS"]._serialized_start = 40 + _globals["_BATTERYSTATUS"]._serialized_end = 217 + _globals["_PVSTATUS"]._serialized_start = 219 + _globals["_PVSTATUS"]._serialized_end = 246 + _globals["_GRIDSTATUS"]._serialized_start = 248 + _globals["_GRIDSTATUS"]._serialized_end = 318 + _globals["_LOADSTATUS"]._serialized_start = 320 + _globals["_LOADSTATUS"]._serialized_end = 349 + _globals["_INVERTERDATA"]._serialized_start = 351 + _globals["_INVERTERDATA"]._serialized_end = 412 + _globals["_CONTROLSTATUS"]._serialized_start = 415 + _globals["_CONTROLSTATUS"]._serialized_end = 633 + _globals["_SCHEDULESTATUS"]._serialized_start = 635 + _globals["_SCHEDULESTATUS"]._serialized_end = 741 + _globals["_EMSSUBINVERTER"]._serialized_start = 743 + _globals["_EMSSUBINVERTER"]._serialized_end = 837 + _globals["_EMSSTATUS"]._serialized_start = 840 + _globals["_EMSSTATUS"]._serialized_end = 1056 + _globals["_ENERGYCOUNTERS"]._serialized_start = 1059 + _globals["_ENERGYCOUNTERS"]._serialized_end = 1255 + _globals["_INVERTERENTRY"]._serialized_start = 1258 + _globals["_INVERTERENTRY"]._serialized_end = 1714 + 
_globals["_GATEWAYSTATUS"]._serialized_start = 1717 + _globals["_GATEWAYSTATUS"]._serialized_end = 1877 + _globals["_PLANENTRY"]._serialized_start = 1880 + _globals["_PLANENTRY"]._serialized_end = 2081 + _globals["_EXECUTIONPLAN"]._serialized_start = 2083 + _globals["_EXECUTIONPLAN"]._serialized_end = 2194 +# @@protoc_insertion_point(module_scope) From c7619484b759c08bed87dbb0765cd0326be100d6 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 06:20:24 +0000 Subject: [PATCH 03/11] feat(gateway): add GatewayMQTT static methods with tests for protobuf decode and commands Co-Authored-By: Claude Opus 4.6 --- .cspell/custom-dictionary-workspace.txt | 1 + apps/predbat/gateway.py | 149 +++++++++++++++++ apps/predbat/tests/test_gateway.py | 205 ++++++++++++++++++++++++ 3 files changed, 355 insertions(+) create mode 100644 apps/predbat/gateway.py create mode 100644 apps/predbat/tests/test_gateway.py diff --git a/.cspell/custom-dictionary-workspace.txt b/.cspell/custom-dictionary-workspace.txt index 217577ac0..74c891ffe 100644 --- a/.cspell/custom-dictionary-workspace.txt +++ b/.cspell/custom-dictionary-workspace.txt @@ -249,6 +249,7 @@ onmouseover openweathermap overfitting ownerapi +pbgw pdata pdetails perc diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py new file mode 100644 index 000000000..ab3684252 --- /dev/null +++ b/apps/predbat/gateway.py @@ -0,0 +1,149 @@ +"""ESP32 Gateway MQTT component. + +Provides full inverter telemetry and control via the ESP32 gateway's +MQTT interface. Registered in COMPONENT_LIST as 'gateway'. This is +the sole data source and control interface for SaaS users with a +gateway — no Home Assistant in the loop. 
"""ESP32 Gateway MQTT component.

Provides full inverter telemetry and control via the ESP32 gateway's
MQTT interface. Registered in COMPONENT_LIST as 'gateway'. This is
the sole data source and control interface for SaaS users with a
gateway — no Home Assistant in the loop.
"""

import json
import time
import uuid


# Entity mapping: protobuf field path → entity name
ENTITY_MAP = {
    # Battery
    "battery.soc_percent": "predbat_gateway_soc",
    "battery.power_w": "predbat_gateway_battery_power",
    "battery.voltage_v": "predbat_gateway_battery_voltage",
    "battery.current_a": "predbat_gateway_battery_current",
    "battery.temperature_c": "predbat_gateway_battery_temp",
    "battery.soh_percent": "predbat_gateway_battery_soh",
    "battery.cycle_count": "predbat_gateway_battery_cycles",
    "battery.capacity_wh": "predbat_gateway_battery_capacity",
    # Power flows
    "pv.power_w": "predbat_gateway_pv_power",
    "grid.power_w": "predbat_gateway_grid_power",
    "grid.voltage_v": "predbat_gateway_grid_voltage",
    "grid.frequency_hz": "predbat_gateway_grid_frequency",
    "load.power_w": "predbat_gateway_load_power",
    "inverter.active_power_w": "predbat_gateway_inverter_power",
    "inverter.temperature_c": "predbat_gateway_inverter_temp",
    # Control
    "control.mode": "predbat_gateway_mode",
    "control.charge_enabled": "predbat_gateway_charge_enabled",
    "control.discharge_enabled": "predbat_gateway_discharge_enabled",
    "control.charge_rate_w": "predbat_gateway_charge_rate",
    "control.discharge_rate_w": "predbat_gateway_discharge_rate",
    "control.reserve_soc": "predbat_gateway_reserve",
    "control.target_soc": "predbat_gateway_target_soc",
    "control.force_power_w": "predbat_gateway_force_power",
    "control.command_expires": "predbat_gateway_command_expires",
    # Schedule
    "schedule.charge_start": "predbat_gateway_charge_start",
    "schedule.charge_end": "predbat_gateway_charge_end",
    "schedule.discharge_start": "predbat_gateway_discharge_start",
    "schedule.discharge_end": "predbat_gateway_discharge_end",
}

# Default deadman window (seconds) for mode commands: the gateway reverts
# to AUTO if no refreshed command arrives before expires_at.
DEFAULT_COMMAND_EXPIRY_S = 300


class GatewayMQTT:
    """ESP32 Gateway MQTT component for PredBat.

    Static methods handle data transformation (protobuf ↔ entities/commands).
    Instance methods handle MQTT lifecycle and ComponentBase integration.
    """

    @staticmethod
    def decode_telemetry(data):
        """Decode protobuf GatewayStatus → dict of entity_name: value.

        Args:
            data: Raw protobuf bytes from /status topic.

        Returns:
            Dict mapping entity names to values. Uses first inverter entry.
        """
        # Imported lazily so the module (and JSON-only helpers such as
        # build_command) remain importable when protobuf is not installed.
        from proto import gateway_status_pb2 as pb

        status = pb.GatewayStatus()
        status.ParseFromString(data)

        if len(status.inverters) == 0:
            return {}

        inv = status.inverters[0]
        entities = {}

        # Walk each dotted field path (e.g. "battery.soc_percent") down the
        # message; a missing attribute anywhere along the path skips the entity.
        for field_path, entity_name in ENTITY_MAP.items():
            parts = field_path.split(".")
            obj = inv
            for part in parts:
                obj = getattr(obj, part, None)
                if obj is None:
                    break
            if obj is not None:
                entities[entity_name] = obj

        return entities

    @staticmethod
    def build_execution_plan(entries, plan_version, timezone):
        """Build protobuf ExecutionPlan from a list of plan entry dicts.

        Args:
            entries: List of dicts with keys matching PlanEntry fields.
            plan_version: Monotonic version number.
            timezone: IANA timezone string (e.g. "Europe/London").

        Returns:
            Serialized protobuf bytes.
        """
        from proto import gateway_status_pb2 as pb  # lazy: see decode_telemetry

        plan = pb.ExecutionPlan()
        plan.timestamp = int(time.time())
        plan.plan_version = plan_version
        plan.timezone = timezone

        for entry_dict in entries:
            pe = plan.entries.add()
            pe.enabled = entry_dict.get("enabled", True)
            pe.start_hour = entry_dict.get("start_hour", 0)
            pe.start_minute = entry_dict.get("start_minute", 0)
            pe.end_hour = entry_dict.get("end_hour", 0)
            pe.end_minute = entry_dict.get("end_minute", 0)
            pe.mode = entry_dict.get("mode", 0)
            pe.power_w = entry_dict.get("power_w", 0)
            pe.target_soc = entry_dict.get("target_soc", 0)
            # 0x7F = all seven days-of-week bits set (bit 0 = Sunday)
            pe.days_of_week = entry_dict.get("days_of_week", 0x7F)
            pe.use_native = entry_dict.get("use_native", False)

        return plan.SerializeToString()

    @staticmethod
    def build_command(command, *, expires_in=DEFAULT_COMMAND_EXPIRY_S, **kwargs):
        """Build JSON command string for ad-hoc control.

        Args:
            command: Command name (set_mode, set_charge_rate, etc.)
            expires_in: Deadman window in seconds applied to mode commands
                (keyword-only; default 300).
            **kwargs: Command-specific fields (mode, power_w, target_soc).

        Returns:
            JSON string ready to publish to /command topic.
        """
        cmd = {
            "command": command,
            "command_id": str(uuid.uuid4()),
        }

        # Copy recognised optional fields in a fixed order so the emitted
        # JSON key order is stable.
        for field in ("mode", "power_w", "target_soc"):
            if field in kwargs:
                cmd[field] = kwargs[field]

        # Mode commands carry expires_at (deadman) so a lost connection
        # cannot leave the inverter pinned in a forced mode.
        if command == "set_mode":
            cmd["expires_at"] = int(time.time()) + expires_in

        return json.dumps(cmd)
"""Tests for GatewayMQTT component."""
import pytest
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from proto import gateway_status_pb2 as pb


class TestProtobufDecode:
    """Test protobuf telemetry → entity mapping."""

    def _make_status(self, soc=50, battery_power=1000, pv_power=2000, grid_power=-500, load_power=1500, mode=0):
        """Build a fully-populated GatewayStatus carrying one inverter entry."""
        status = pb.GatewayStatus()
        status.device_id = "pbgw_test123"
        status.firmware = "0.4.5"
        status.timestamp = 1741789200
        status.schema_version = 1
        status.dongle_count = 1

        inv = status.inverters.add()
        inv.type = pb.INVERTER_TYPE_GIVENERGY
        inv.serial = "CE1234G567"
        inv.ip = "192.168.1.100"
        inv.connected = True
        inv.active = True

        batt = inv.battery
        batt.soc_percent = soc
        batt.power_w = battery_power
        batt.voltage_v = 51.2
        batt.current_a = 19.5
        batt.temperature_c = 22.5
        batt.soh_percent = 98
        batt.cycle_count = 150
        batt.capacity_wh = 9500

        inv.pv.power_w = pv_power
        grid = inv.grid
        grid.power_w = grid_power
        grid.voltage_v = 242.5
        grid.frequency_hz = 50.01
        inv.load.power_w = load_power

        inv.inverter.active_power_w = 1800
        inv.inverter.temperature_c = 35.0

        ctl = inv.control
        ctl.mode = mode
        ctl.charge_enabled = True
        ctl.discharge_enabled = True
        ctl.charge_rate_w = 3000
        ctl.discharge_rate_w = 3000
        ctl.reserve_soc = 4
        ctl.target_soc = 100

        sched = inv.schedule
        sched.charge_start = 130
        sched.charge_end = 430
        sched.discharge_start = 1600
        sched.discharge_end = 1900

        return status

    def test_serialize_deserialize_roundtrip(self):
        wire = self._make_status(soc=75, battery_power=2000).SerializeToString()
        decoded = pb.GatewayStatus()
        decoded.ParseFromString(wire)

        assert decoded.device_id == "pbgw_test123"
        first = decoded.inverters[0]
        assert first.battery.soc_percent == 75
        assert first.battery.power_w == 2000
        assert first.pv.power_w == 2000
        assert first.grid.power_w == -500
        assert first.grid.voltage_v == pytest.approx(242.5, abs=0.1)
        assert first.control.charge_enabled is True
        assert first.battery.soh_percent == 98

    def test_entity_mapping(self):
        from gateway import GatewayMQTT

        entities = GatewayMQTT.decode_telemetry(self._make_status().SerializeToString())

        # Exact-valued entities, table-driven
        expected_exact = {
            "predbat_gateway_soc": 50,
            "predbat_gateway_battery_power": 1000,
            "predbat_gateway_pv_power": 2000,
            "predbat_gateway_grid_power": -500,
            "predbat_gateway_load_power": 1500,
            "predbat_gateway_battery_soh": 98,
            "predbat_gateway_battery_cycles": 150,
            "predbat_gateway_battery_capacity": 9500,
            "predbat_gateway_inverter_power": 1800,
            "predbat_gateway_mode": 0,
            "predbat_gateway_charge_rate": 3000,
            "predbat_gateway_discharge_rate": 3000,
            "predbat_gateway_reserve": 4,
            "predbat_gateway_target_soc": 100,
            "predbat_gateway_charge_start": 130,
            "predbat_gateway_charge_end": 430,
            "predbat_gateway_discharge_start": 1600,
            "predbat_gateway_discharge_end": 1900,
        }
        for name, value in expected_exact.items():
            assert entities[name] == value

        # Float entities with per-field tolerance
        expected_approx = {
            "predbat_gateway_battery_voltage": (51.2, 0.1),
            "predbat_gateway_battery_current": (19.5, 0.1),
            "predbat_gateway_battery_temp": (22.5, 0.1),
            "predbat_gateway_grid_voltage": (242.5, 0.1),
            "predbat_gateway_grid_frequency": (50.01, 0.01),
            "predbat_gateway_inverter_temp": (35.0, 0.1),
        }
        for name, (value, tol) in expected_approx.items():
            assert entities[name] == pytest.approx(value, abs=tol)

        # Booleans must survive decoding as real bools
        assert entities["predbat_gateway_charge_enabled"] is True
        assert entities["predbat_gateway_discharge_enabled"] is True


class TestPlanSerialization:
    def test_plan_roundtrip(self):
        from gateway import GatewayMQTT

        charge_window = {
            "enabled": True,
            "start_hour": 1,
            "start_minute": 30,
            "end_hour": 4,
            "end_minute": 30,
            "mode": 1,
            "power_w": 3000,
            "target_soc": 100,
            "days_of_week": 0x7F,
            "use_native": True,
        }
        discharge_window = {
            "enabled": True,
            "start_hour": 16,
            "start_minute": 0,
            "end_hour": 19,
            "end_minute": 0,
            "mode": 2,
            "power_w": 2500,
            "target_soc": 10,
            "days_of_week": 0x7F,
            "use_native": False,
        }

        wire = GatewayMQTT.build_execution_plan([charge_window, discharge_window], plan_version=42, timezone="Europe/London")

        plan = pb.ExecutionPlan()
        plan.ParseFromString(wire)

        assert plan.plan_version == 42
        assert plan.timezone == "Europe/London"
        assert len(plan.entries) == 2
        first, second = plan.entries
        assert first.start_hour == 1
        assert first.start_minute == 30
        assert first.mode == 1
        assert first.use_native is True
        assert second.mode == 2
        assert second.use_native is False

    def test_empty_plan(self):
        from gateway import GatewayMQTT

        plan = pb.ExecutionPlan()
        plan.ParseFromString(GatewayMQTT.build_execution_plan([], plan_version=1, timezone="UTC"))
        assert len(plan.entries) == 0
        assert plan.plan_version == 1


class TestCommandFormat:
    def test_set_mode_command(self):
        import json
        import time

        from gateway import GatewayMQTT

        parsed = json.loads(GatewayMQTT.build_command("set_mode", mode=1, power_w=3000, target_soc=100))
        assert parsed["command"] == "set_mode"
        assert parsed["mode"] == 1
        assert parsed["power_w"] == 3000
        assert parsed["target_soc"] == 100
        assert "command_id" in parsed
        assert "expires_at" in parsed
        assert abs(parsed["expires_at"] - int(time.time())) < 310

    def test_set_charge_rate_command(self):
        import json

        from gateway import GatewayMQTT

        parsed = json.loads(GatewayMQTT.build_command("set_charge_rate", power_w=2500))
        assert parsed["command"] == "set_charge_rate"
        assert parsed["power_w"] == 2500

    def test_set_reserve_command(self):
        import json

        from gateway import GatewayMQTT

        parsed = json.loads(GatewayMQTT.build_command("set_reserve", target_soc=10))
        assert parsed["command"] == "set_reserve"
        assert parsed["target_soc"] == 10
mode/rate/SOC changes - final() sends AUTO mode and cancels listener on shutdown - Token refresh via Supabase edge function (same pattern as OAuthMixin) All existing static methods and tests preserved unchanged. Co-Authored-By: Claude Opus 4.6 --- .cspell/custom-dictionary-workspace.txt | 2 + apps/predbat/gateway.py | 428 +++++++++++++++++++++++- 2 files changed, 427 insertions(+), 3 deletions(-) diff --git a/.cspell/custom-dictionary-workspace.txt b/.cspell/custom-dictionary-workspace.txt index 74c891ffe..41bf7293f 100644 --- a/.cspell/custom-dictionary-workspace.txt +++ b/.cspell/custom-dictionary-workspace.txt @@ -6,6 +6,7 @@ afci AIO AIO's aiohttp +aiomqtt Alertfeed allclose Anson @@ -169,6 +170,7 @@ ivtime jedlix jsyaml kaiming +keepalive killall kopt Kostal diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index ab3684252..1ad1f1863 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -6,11 +6,26 @@ gateway — no Home Assistant in the loop. """ +import asyncio import json +import os +import ssl import time import uuid +import traceback + +import aiohttp +from datetime import datetime +from component_base import ComponentBase from proto import gateway_status_pb2 as pb +try: + import aiomqtt + + HAS_AIOMQTT = True +except ImportError: + HAS_AIOMQTT = False + # Entity mapping: protobuf field path → entity name ENTITY_MAP = { @@ -48,17 +63,424 @@ "schedule.discharge_end": "predbat_gateway_discharge_end", } +# Token refresh threshold — refresh when less than 2 hours remaining +_TOKEN_REFRESH_THRESHOLD = 2 * 60 * 60 + +# Plan re-publish interval (seconds) +_PLAN_REPUBLISH_INTERVAL = 5 * 60 -class GatewayMQTT: +# Telemetry staleness threshold (seconds) +_TELEMETRY_STALE_THRESHOLD = 120 + + +class GatewayMQTT(ComponentBase): """ESP32 Gateway MQTT component for PredBat. - Static methods handle data transformation (protobuf ↔ entities/commands). + Static methods handle data transformation (protobuf <-> entities/commands). 
Instance methods handle MQTT lifecycle and ComponentBase integration. """ + def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqtt_token=None, mqtt_refresh_token=None, **kwargs): + """Initialize gateway configuration and build MQTT topic strings. + + Args: + gateway_device_id: The gateway's device ID (e.g. "pbgw_abc123"). + mqtt_host: MQTT broker hostname. + mqtt_port: MQTT broker port (default 8883 for TLS). + mqtt_token: JWT access token for MQTT authentication. + mqtt_refresh_token: Refresh token for token renewal. + **kwargs: Additional keyword arguments (ignored). + """ + self.gateway_device_id = gateway_device_id + self.mqtt_host = mqtt_host + self.mqtt_port = mqtt_port + self.mqtt_token = mqtt_token + self.mqtt_refresh_token = mqtt_refresh_token + self.mqtt_token_expires_at = 0 + + # MQTT topic strings + self._topic_base = f"gw/{gateway_device_id}" if gateway_device_id else "gw/unknown" + self.topic_status = f"{self._topic_base}/status" + self.topic_online = f"{self._topic_base}/online" + self.topic_schedule = f"{self._topic_base}/schedule" + self.topic_command = f"{self._topic_base}/command" + + # Runtime state + self._mqtt_client = None + self._mqtt_task = None + self._mqtt_connected = False + self._gateway_online = False + self._last_telemetry_time = 0 + self._last_plan_data = None + self._last_plan_publish_time = 0 + self._plan_version = 0 + self._refresh_in_progress = False + + async def run(self, seconds, first): + """Component run loop — called every 60 seconds by ComponentBase.start(). + + On the first call, starts the background MQTT listener task. + Subsequent calls perform housekeeping: token refresh checks and + plan re-publishing if stale. + + Args: + seconds: Elapsed seconds since component start. + first: True on the first invocation. + + Returns: + True on success, False on failure. 
+ """ + if not HAS_AIOMQTT: + self.log("Error: GatewayMQTT: aiomqtt not installed — cannot start") + return False + + if not self.gateway_device_id or not self.mqtt_host: + self.log("Error: GatewayMQTT: gateway_device_id and mqtt_host are required") + return False + + if first: + # Start MQTT listener as a background task + self._mqtt_task = asyncio.ensure_future(self._mqtt_loop()) + self.log("Info: GatewayMQTT: MQTT listener task started") + return True + + # Housekeeping on subsequent runs + try: + # Check if MQTT task died unexpectedly + if self._mqtt_task and self._mqtt_task.done(): + exc = self._mqtt_task.exception() if not self._mqtt_task.cancelled() else None + if exc: + self.log(f"Warn: GatewayMQTT: MQTT task died with: {exc}") + self.log("Info: GatewayMQTT: Restarting MQTT listener task") + self._mqtt_task = asyncio.ensure_future(self._mqtt_loop()) + + # Token refresh check + await self._check_token_refresh() + + # Re-publish plan if stale + if self._last_plan_data and self._mqtt_connected: + elapsed = time.time() - self._last_plan_publish_time + if elapsed > _PLAN_REPUBLISH_INTERVAL: + await self._publish_raw(self.topic_schedule, self._last_plan_data) + self._last_plan_publish_time = time.time() + self.log("Info: GatewayMQTT: Re-published execution plan (stale)") + + except Exception as e: + self.log(f"Warn: GatewayMQTT: housekeeping error: {e}") + + return True + + async def _mqtt_loop(self): + """Continuous MQTT listener with automatic reconnection. + + Connects to the broker with TLS, subscribes to status and online + topics, and dispatches incoming messages. Reconnects on failure + with exponential backoff. 
+ """ + backoff = 5 + max_backoff = 120 + + while not self.api_stop: + try: + tls_context = ssl.create_default_context() + + client_id = f"predbat-{self.gateway_device_id}-{uuid.uuid4().hex[:8]}" + + async with aiomqtt.Client( + hostname=self.mqtt_host, + port=self.mqtt_port, + username=self.gateway_device_id, + password=self.mqtt_token, + tls_context=tls_context, + identifier=client_id, + keepalive=60, + ) as client: + self._mqtt_client = client + self._mqtt_connected = True + backoff = 5 # Reset backoff on successful connection + self.log(f"Info: GatewayMQTT: Connected to {self.mqtt_host}:{self.mqtt_port}") + + # Subscribe to status and LWT topics + await client.subscribe(self.topic_status, qos=1) + await client.subscribe(self.topic_online, qos=1) + self.log(f"Info: GatewayMQTT: Subscribed to {self.topic_status} and {self.topic_online}") + + async for message in client.messages: + if self.api_stop: + break + await self._handle_message(message) + + except asyncio.CancelledError: + self.log("Info: GatewayMQTT: MQTT loop cancelled") + break + except Exception as e: + self.log(f"Warn: GatewayMQTT: MQTT connection error: {e}") + self._mqtt_connected = False + self._mqtt_client = None + + if self.api_stop: + break + + self.log(f"Info: GatewayMQTT: Reconnecting in {backoff}s") + await asyncio.sleep(backoff) + backoff = min(backoff * 2, max_backoff) + + self._mqtt_connected = False + self._mqtt_client = None + + async def _handle_message(self, message): + """Dispatch an incoming MQTT message to the appropriate handler. + + Args: + message: An aiomqtt.Message with topic and payload. 
+ """ + topic = str(message.topic) + + try: + if topic == self.topic_status: + self._process_telemetry(message.payload) + elif topic == self.topic_online: + payload = message.payload.decode("utf-8", errors="replace").strip() + was_online = self._gateway_online + self._gateway_online = payload == "1" + if self._gateway_online != was_online: + state = "online" if self._gateway_online else "offline" + self.log(f"Info: GatewayMQTT: Gateway is {state}") + self.set_state_wrapper( + f"sensor.{self.prefix}predbat_gateway_online", + state, + attributes={"friendly_name": "Gateway Online"}, + ) + except Exception as e: + self.log(f"Warn: GatewayMQTT: Error handling message on {topic}: {e}") + self.log(f"Warn: {traceback.format_exc()}") + + def _process_telemetry(self, data): + """Decode telemetry protobuf and publish entities via set_state_wrapper. + + Args: + data: Raw protobuf bytes from the /status topic. + """ + entities = self.decode_telemetry(data) + if not entities: + return + + self._last_telemetry_time = time.time() + self.update_success_timestamp() + + for entity_name, value in entities.items(): + self.set_state_wrapper( + f"sensor.{self.prefix}{entity_name}", + value, + attributes={"friendly_name": entity_name.replace("predbat_gateway_", "Gateway ").replace("_", " ").title()}, + ) + + async def publish_plan(self, plan_entries, timezone_str): + """Build and publish an ExecutionPlan protobuf to the gateway. + + Args: + plan_entries: List of plan entry dicts. + timezone_str: IANA timezone string (e.g. "Europe/London"). 
+ """ + self._plan_version += 1 + data = self.build_execution_plan(plan_entries, plan_version=self._plan_version, timezone=timezone_str) + self._last_plan_data = data + self._last_plan_publish_time = time.time() + + if self._mqtt_connected: + await self._publish_raw(self.topic_schedule, data) + self.log(f"Info: GatewayMQTT: Published execution plan v{self._plan_version} ({len(plan_entries)} entries)") + else: + self.log("Warn: GatewayMQTT: Not connected — plan queued for next publish") + + async def publish_command(self, command, **kwargs): + """Build and publish a JSON command to the gateway. + + Args: + command: Command name (set_mode, set_charge_rate, etc.) + **kwargs: Command-specific fields (mode, power_w, target_soc). + """ + cmd_json = self.build_command(command, **kwargs) + + if self._mqtt_connected: + await self._publish_raw(self.topic_command, cmd_json.encode("utf-8")) + self.log(f"Info: GatewayMQTT: Published command: {command}") + else: + self.log(f"Warn: GatewayMQTT: Not connected — cannot publish command: {command}") + + async def _publish_raw(self, topic, payload): + """Publish raw bytes to an MQTT topic. + + Args: + topic: MQTT topic string. + payload: Bytes to publish. + """ + if self._mqtt_client and self._mqtt_connected: + await self._mqtt_client.publish(topic, payload, qos=1) + + def is_alive(self): + """Check if the gateway component is alive and receiving data. + + Returns True when MQTT is connected AND either the gateway is + offline (LWT says so — we're still connected, just no data) OR + we've received telemetry within the last 2 minutes. + + Returns: + bool: True if healthy, False otherwise. 
+ """ + if not self._mqtt_connected: + return False + + if not self._gateway_online: + # Gateway is offline but we're connected to broker — that's OK + return True + + # Gateway is online — check telemetry freshness + if self._last_telemetry_time == 0: + return False + + return (time.time() - self._last_telemetry_time) < _TELEMETRY_STALE_THRESHOLD + + async def select_event(self, entity_id, value): + """Handle select entity changes (e.g. mode selection). + + Args: + entity_id: The entity ID that changed. + value: The new selected value. + """ + if "gateway_mode" in entity_id: + mode_map = {"auto": 0, "charge": 1, "discharge": 2, "idle": 3} + mode_val = mode_map.get(str(value).lower()) + if mode_val is not None: + await self.publish_command("set_mode", mode=mode_val) + self.log(f"Info: GatewayMQTT: Mode set to {value} ({mode_val})") + + async def number_event(self, entity_id, value): + """Handle number entity changes (e.g. charge rate, target SOC). + + Args: + entity_id: The entity ID that changed. + value: The new numeric value. + """ + try: + val = int(float(value)) + except (ValueError, TypeError): + self.log(f"Warn: GatewayMQTT: Invalid number value: {value}") + return + + if "charge_rate" in entity_id: + await self.publish_command("set_charge_rate", power_w=val) + elif "discharge_rate" in entity_id: + await self.publish_command("set_discharge_rate", power_w=val) + elif "reserve" in entity_id: + await self.publish_command("set_reserve", target_soc=val) + elif "target_soc" in entity_id: + await self.publish_command("set_target_soc", target_soc=val) + + async def switch_event(self, entity_id, service): + """Handle switch entity service calls. Stub for v1. + + Args: + entity_id: The entity ID being controlled. + service: The service being called (turn_on/turn_off). 
+ """ + pass + + async def final(self): + """Cleanup: send AUTO mode, cancel listener task, disconnect.""" + try: + # Send AUTO mode before disconnecting + if self._mqtt_connected: + await self.publish_command("set_mode", mode=0) + self.log("Info: GatewayMQTT: Sent AUTO mode on shutdown") + except Exception as e: + self.log(f"Warn: GatewayMQTT: Error sending final AUTO mode: {e}") + + # Cancel the MQTT listener task + if self._mqtt_task and not self._mqtt_task.done(): + self._mqtt_task.cancel() + try: + await self._mqtt_task + except (asyncio.CancelledError, Exception): + pass + + self._mqtt_connected = False + self._mqtt_client = None + self.log("Info: GatewayMQTT: Finalized") + + async def _check_token_refresh(self): + """Check if the MQTT JWT token needs refreshing and refresh if needed. + + Uses the Supabase edge function (same pattern as OAuthMixin) to + obtain a new access token before the current one expires. + """ + if not self.mqtt_refresh_token: + return + + if self.mqtt_token_expires_at and time.time() < (self.mqtt_token_expires_at - _TOKEN_REFRESH_THRESHOLD): + return + + if self._refresh_in_progress: + return + + self._refresh_in_progress = True + try: + supabase_url = os.environ.get("SUPABASE_URL", "") + supabase_key = os.environ.get("SUPABASE_KEY", "") + instance_id = self.args.get("user_id", "") if isinstance(self.args, dict) else "" + + if not supabase_url or not supabase_key or not instance_id: + self.log("Warn: GatewayMQTT: Token refresh skipped — missing env vars or instance_id") + return + + url = f"{supabase_url}/functions/v1/gateway-token-refresh" + headers = { + "Authorization": f"Bearer {supabase_key}", + "Content-Type": "application/json", + } + payload = { + "instance_id": instance_id, + "refresh_token": self.mqtt_refresh_token, + } + + self.log("Info: GatewayMQTT: Refreshing MQTT token") + + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, 
headers=headers, json=payload) as response: + if response.status != 200: + self.log(f"Warn: GatewayMQTT: Token refresh HTTP {response.status}") + return + + data = await response.json() + + if data.get("success"): + self.mqtt_token = data["access_token"] + if data.get("refresh_token"): + self.mqtt_refresh_token = data["refresh_token"] + if data.get("expires_at"): + try: + if isinstance(data["expires_at"], (int, float)): + self.mqtt_token_expires_at = float(data["expires_at"]) + else: + dt = datetime.fromisoformat(data["expires_at"].replace("Z", "+00:00")) + self.mqtt_token_expires_at = dt.timestamp() + except (ValueError, AttributeError): + self.mqtt_token_expires_at = 0 + self.log("Info: GatewayMQTT: MQTT token refreshed") + else: + self.log(f"Warn: GatewayMQTT: Token refresh failed: {data.get('error', 'unknown')}") + + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + self.log(f"Warn: GatewayMQTT: Token refresh network error: {e}") + except Exception as e: + self.log(f"Warn: GatewayMQTT: Token refresh error: {e}") + finally: + self._refresh_in_progress = False + @staticmethod def decode_telemetry(data): - """Decode protobuf GatewayStatus → dict of entity_name: value. + """Decode protobuf GatewayStatus -> dict of entity_name: value. Args: data: Raw protobuf bytes from /status topic. 
From 23e129d1b6fdc6e4d25c4e6b594eedf89f4a6b4f Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:30:14 +0000 Subject: [PATCH 05/11] feat(gateway): register GatewayMQTT in COMPONENT_LIST Co-Authored-By: Claude Opus 4.6 --- apps/predbat/components.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/apps/predbat/components.py b/apps/predbat/components.py index 8a6afc2ea..a58cd82f9 100644 --- a/apps/predbat/components.py +++ b/apps/predbat/components.py @@ -31,6 +31,7 @@ from ha import HAInterface, HAHistory from db_manager import DatabaseManager from fox import FoxAPI +from gateway import GatewayMQTT from web_mcp import PredbatMCPServer from load_ml_component import LoadMLComponent from datetime import datetime, timezone, timedelta @@ -307,6 +308,20 @@ "phase": 1, "can_restart": True, }, + "gateway": { + "class": GatewayMQTT, + "name": "PredBat Gateway", + "event_filter": "predbat_gateway_", + "args": { + "gateway_device_id": {"required": True, "config": "gateway_device_id"}, + "mqtt_host": {"required": True, "config": "gateway_mqtt_host"}, + "mqtt_port": {"required": False, "config": "gateway_mqtt_port", "default": 8883}, + "mqtt_token": {"required": True, "config": "gateway_mqtt_token"}, + "mqtt_refresh_token": {"required": False, "config": "gateway_mqtt_refresh_token"}, + }, + "phase": 1, + "can_restart": True, + }, } From c51b3637d9c79073a26c1c458570b548559118a5 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:33:18 +0000 Subject: [PATCH 06/11] feat(gateway): implement JWT token refresh with 1-hour pre-expiry renewal Add extract_jwt_expiry and token_needs_refresh static methods. Wire into _check_token_refresh to extract expiry from JWT claims directly. 
Co-Authored-By: Claude Opus 4.6 --- apps/predbat/gateway.py | 41 +++++++++++++++++++++++++++++- apps/predbat/tests/test_gateway.py | 36 ++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index 1ad1f1863..d902b429f 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -417,7 +417,11 @@ async def _check_token_refresh(self): if not self.mqtt_refresh_token: return - if self.mqtt_token_expires_at and time.time() < (self.mqtt_token_expires_at - _TOKEN_REFRESH_THRESHOLD): + # Extract expiry from JWT if not yet known + if not self.mqtt_token_expires_at and self.mqtt_token: + self.mqtt_token_expires_at = self.extract_jwt_expiry(self.mqtt_token) + + if self.mqtt_token_expires_at and not self.token_needs_refresh(self.mqtt_token_expires_at): return if self._refresh_in_progress: @@ -541,6 +545,41 @@ def build_execution_plan(entries, plan_version, timezone): return plan.SerializeToString() + @staticmethod + def extract_jwt_expiry(jwt_token): + """Extract the exp claim from a JWT without verifying signature. + + Args: + jwt_token: JWT string (header.payload.signature). + + Returns: + Unix timestamp of expiry, or 0 if parsing fails. + """ + import base64 + + try: + parts = jwt_token.split(".") + if len(parts) != 3: + return 0 + # Add padding + payload_b64 = parts[1] + "=" * (4 - len(parts[1]) % 4) + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + return payload.get("exp", 0) + except Exception: + return 0 + + @staticmethod + def token_needs_refresh(exp_epoch): + """Check if token should be refreshed (1 hour before expiry). + + Args: + exp_epoch: Unix timestamp of token expiry. + + Returns: + True if token expires within 1 hour. + """ + return (exp_epoch - int(time.time())) < 3600 + @staticmethod def build_command(command, **kwargs): """Build JSON command string for ad-hoc control. 
diff --git a/apps/predbat/tests/test_gateway.py b/apps/predbat/tests/test_gateway.py index 70ebb84a4..e395a0459 100644 --- a/apps/predbat/tests/test_gateway.py +++ b/apps/predbat/tests/test_gateway.py @@ -203,3 +203,39 @@ def test_set_reserve_command(self): parsed = json.loads(cmd) assert parsed["command"] == "set_reserve" assert parsed["target_soc"] == 10 + + +class TestTokenRefresh: + def test_jwt_expiry_extraction(self): + """Extract exp claim from a JWT without verification.""" + from gateway import GatewayMQTT + import base64 + import json as json_mod + + # Build a fake JWT with exp claim + header = base64.urlsafe_b64encode(json_mod.dumps({"alg": "RS256"}).encode()).rstrip(b"=") + payload = base64.urlsafe_b64encode(json_mod.dumps({"exp": 1741789200, "sub": "test"}).encode()).rstrip(b"=") + fake_jwt = f"{header.decode()}.{payload.decode()}.fake_signature" + + exp = GatewayMQTT.extract_jwt_expiry(fake_jwt) + assert exp == 1741789200 + + def test_jwt_expiry_invalid_token(self): + """Invalid JWT returns 0.""" + from gateway import GatewayMQTT + + assert GatewayMQTT.extract_jwt_expiry("not-a-jwt") == 0 + assert GatewayMQTT.extract_jwt_expiry("") == 0 + + def test_token_needs_refresh(self): + """Token should be refreshed 1 hour before expiry.""" + from gateway import GatewayMQTT + import time as time_mod + + # Token expiring in 30 minutes — needs refresh + exp_soon = int(time_mod.time()) + 1800 + assert GatewayMQTT.token_needs_refresh(exp_soon) is True + + # Token expiring in 2 hours — does not need refresh + exp_later = int(time_mod.time()) + 7200 + assert GatewayMQTT.token_needs_refresh(exp_later) is False From b039559a941220035e5834e05117cb3054b83fe9 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:35:04 +0000 Subject: [PATCH 07/11] feat(gateway): add EMS multi-inverter entity mapping with per-sub-inverter entities Co-Authored-By: Claude Opus 4.6 --- apps/predbat/gateway.py | 18 ++++++++++++ apps/predbat/tests/test_gateway.py | 45 
++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index d902b429f..d20812a9e 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -511,6 +511,24 @@ def decode_telemetry(data): if obj is not None: entities[entity_name] = obj + # EMS aggregate entities (when type is GIVENERGY_EMS) + if inv.type == pb.INVERTER_TYPE_GIVENERGY_EMS and inv.ems.num_inverters > 0: + entities["predbat_gateway_ems_total_soc"] = inv.ems.total_soc + entities["predbat_gateway_ems_total_charge"] = inv.ems.total_charge_w + entities["predbat_gateway_ems_total_discharge"] = inv.ems.total_discharge_w + entities["predbat_gateway_ems_total_grid"] = inv.ems.total_grid_w + entities["predbat_gateway_ems_total_pv"] = inv.ems.total_pv_w + entities["predbat_gateway_ems_total_load"] = inv.ems.total_load_w + + # Per-sub-inverter entities + for idx, sub in enumerate(inv.ems.sub_inverters): + prefix = f"predbat_gateway_sub{idx}" + entities[f"{prefix}_soc"] = sub.soc + entities[f"{prefix}_battery_power"] = sub.battery_w + entities[f"{prefix}_pv_power"] = sub.pv_w + entities[f"{prefix}_grid_power"] = sub.grid_w + entities[f"{prefix}_temp"] = sub.temp_c + return entities @staticmethod diff --git a/apps/predbat/tests/test_gateway.py b/apps/predbat/tests/test_gateway.py index e395a0459..a2f5a6a26 100644 --- a/apps/predbat/tests/test_gateway.py +++ b/apps/predbat/tests/test_gateway.py @@ -205,6 +205,51 @@ def test_set_reserve_command(self): assert parsed["target_soc"] == 10 +class TestEMSEntities: + def test_ems_aggregate_entities(self): + """EMS type produces aggregate entities.""" + status = pb.GatewayStatus() + status.device_id = "pbgw_ems_test" + status.timestamp = 1741789200 + status.schema_version = 1 + status.dongle_count = 1 + + inv = status.inverters.add() + inv.type = pb.INVERTER_TYPE_GIVENERGY_EMS + inv.serial = "EM1234" + inv.connected = True + inv.active = True + + inv.ems.num_inverters = 2 + inv.ems.total_soc 
= 60 + inv.ems.total_charge_w = 3000 + inv.ems.total_pv_w = 5000 + inv.ems.total_grid_w = -1000 + inv.ems.total_load_w = 4000 + + sub0 = inv.ems.sub_inverters.add() + sub0.soc = 55 + sub0.battery_w = 1500 + sub0.pv_w = 2500 + sub1 = inv.ems.sub_inverters.add() + sub1.soc = 65 + sub1.battery_w = 1500 + sub1.pv_w = 2500 + + from gateway import GatewayMQTT + + entities = GatewayMQTT.decode_telemetry(status.SerializeToString()) + + # EMS aggregate entities + assert entities.get("predbat_gateway_ems_total_soc") == 60 + assert entities.get("predbat_gateway_ems_total_pv") == 5000 + assert entities.get("predbat_gateway_ems_total_load") == 4000 + # Per-sub-inverter + assert entities.get("predbat_gateway_sub0_soc") == 55 + assert entities.get("predbat_gateway_sub1_soc") == 65 + assert entities.get("predbat_gateway_sub0_battery_power") == 1500 + + class TestTokenRefresh: def test_jwt_expiry_extraction(self): """Extract exp claim from a JWT without verification.""" From e32460eea6ab42717f6a8824685e2f9628a33062 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:38:57 +0000 Subject: [PATCH 08/11] test(gateway): add MQTT integration test for plan publish format Co-Authored-By: Claude Opus 4.6 --- apps/predbat/tests/test_gateway.py | 43 ++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/apps/predbat/tests/test_gateway.py b/apps/predbat/tests/test_gateway.py index a2f5a6a26..94285abdb 100644 --- a/apps/predbat/tests/test_gateway.py +++ b/apps/predbat/tests/test_gateway.py @@ -7,6 +7,10 @@ from proto import gateway_status_pb2 as pb +import importlib.util + +HAS_AIOMQTT = importlib.util.find_spec("aiomqtt") is not None + class TestProtobufDecode: """Test protobuf telemetry → entity mapping.""" @@ -284,3 +288,42 @@ def test_token_needs_refresh(self): # Token expiring in 2 hours — does not need refresh exp_later = int(time_mod.time()) + 7200 assert GatewayMQTT.token_needs_refresh(exp_later) is False + + +class TestMQTTIntegration: + 
"""Integration tests for MQTT plan publishing format.""" + + @pytest.mark.skipif(not HAS_AIOMQTT, reason="aiomqtt not installed") + def test_plan_publish_format(self): + """Plan published to /schedule topic is valid protobuf.""" + from gateway import GatewayMQTT + + entries = [ + { + "enabled": True, + "start_hour": 1, + "start_minute": 30, + "end_hour": 4, + "end_minute": 30, + "mode": 1, + "power_w": 3000, + "target_soc": 100, + "days_of_week": 0x7F, + "use_native": True, + } + ] + + data = GatewayMQTT.build_execution_plan(entries, plan_version=1, timezone="Europe/London") + + # Verify the protobuf is valid and can be decoded + plan = pb.ExecutionPlan() + plan.ParseFromString(data) + assert plan.entries[0].start_hour == 1 + assert plan.entries[0].use_native is True + assert plan.timezone == "Europe/London" + + # Verify plan_version is monotonically increasing + data2 = GatewayMQTT.build_execution_plan(entries, plan_version=2, timezone="Europe/London") + plan2 = pb.ExecutionPlan() + plan2.ParseFromString(data2) + assert plan2.plan_version > plan.plan_version From 5910c3c06650c73be13645f4ee578cc83c4bbedd Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:43:20 +0000 Subject: [PATCH 09/11] fix(gateway): address spec compliance review findings - Fix MQTT topic prefix: gw/ -> predbat/devices/ (critical) - Add retain=True on schedule topic publish - Set api_started=True on first telemetry decode - Add get_error_count() with error tracking - Fix edge function name: refresh-mqtt-token (matching spec) - Remove unused _TOKEN_REFRESH_THRESHOLD constant Co-Authored-By: Claude Opus 4.6 --- apps/predbat/gateway.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index d20812a9e..6a674a8b5 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -63,9 +63,6 @@ "schedule.discharge_end": "predbat_gateway_discharge_end", } -# Token refresh threshold — 
refresh when less than 2 hours remaining -_TOKEN_REFRESH_THRESHOLD = 2 * 60 * 60 - # Plan re-publish interval (seconds) _PLAN_REPUBLISH_INTERVAL = 5 * 60 @@ -99,7 +96,7 @@ def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqt self.mqtt_token_expires_at = 0 # MQTT topic strings - self._topic_base = f"gw/{gateway_device_id}" if gateway_device_id else "gw/unknown" + self._topic_base = f"predbat/devices/{gateway_device_id}" if gateway_device_id else "predbat/devices/unknown" self.topic_status = f"{self._topic_base}/status" self.topic_online = f"{self._topic_base}/online" self.topic_schedule = f"{self._topic_base}/schedule" @@ -115,6 +112,7 @@ def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqt self._last_plan_publish_time = 0 self._plan_version = 0 self._refresh_in_progress = False + self._error_count = 0 async def run(self, seconds, first): """Component run loop — called every 60 seconds by ComponentBase.start(). @@ -161,7 +159,7 @@ async def run(self, seconds, first): if self._last_plan_data and self._mqtt_connected: elapsed = time.time() - self._last_plan_publish_time if elapsed > _PLAN_REPUBLISH_INTERVAL: - await self._publish_raw(self.topic_schedule, self._last_plan_data) + await self._publish_raw(self.topic_schedule, self._last_plan_data, retain=True) self._last_plan_publish_time = time.time() self.log("Info: GatewayMQTT: Re-published execution plan (stale)") @@ -214,6 +212,7 @@ async def _mqtt_loop(self): self.log("Info: GatewayMQTT: MQTT loop cancelled") break except Exception as e: + self._error_count += 1 self.log(f"Warn: GatewayMQTT: MQTT connection error: {e}") self._mqtt_connected = False self._mqtt_client = None @@ -252,6 +251,7 @@ async def _handle_message(self, message): attributes={"friendly_name": "Gateway Online"}, ) except Exception as e: + self._error_count += 1 self.log(f"Warn: GatewayMQTT: Error handling message on {topic}: {e}") self.log(f"Warn: {traceback.format_exc()}") @@ -268,6 +268,10 
@@ def _process_telemetry(self, data): self._last_telemetry_time = time.time() self.update_success_timestamp() + if not self.api_started: + self.api_started = True + self.log("Info: GatewayMQTT: First telemetry received, API started") + for entity_name, value in entities.items(): self.set_state_wrapper( f"sensor.{self.prefix}{entity_name}", @@ -288,7 +292,7 @@ async def publish_plan(self, plan_entries, timezone_str): self._last_plan_publish_time = time.time() if self._mqtt_connected: - await self._publish_raw(self.topic_schedule, data) + await self._publish_raw(self.topic_schedule, data, retain=True) self.log(f"Info: GatewayMQTT: Published execution plan v{self._plan_version} ({len(plan_entries)} entries)") else: self.log("Warn: GatewayMQTT: Not connected — plan queued for next publish") @@ -308,15 +312,16 @@ async def publish_command(self, command, **kwargs): else: self.log(f"Warn: GatewayMQTT: Not connected — cannot publish command: {command}") - async def _publish_raw(self, topic, payload): + async def _publish_raw(self, topic, payload, retain=False): """Publish raw bytes to an MQTT topic. Args: topic: MQTT topic string. payload: Bytes to publish. + retain: Whether to set the retain flag. """ if self._mqtt_client and self._mqtt_connected: - await self._mqtt_client.publish(topic, payload, qos=1) + await self._mqtt_client.publish(topic, payload, qos=1, retain=retain) def is_alive(self): """Check if the gateway component is alive and receiving data. @@ -341,6 +346,10 @@ def is_alive(self): return (time.time() - self._last_telemetry_time) < _TELEMETRY_STALE_THRESHOLD + def get_error_count(self): + """Return the cumulative error count (decode failures, MQTT disconnects, publish failures).""" + return self._error_count + async def select_event(self, entity_id, value): """Handle select entity changes (e.g. mode selection). 
@@ -437,7 +446,7 @@ async def _check_token_refresh(self): self.log("Warn: GatewayMQTT: Token refresh skipped — missing env vars or instance_id") return - url = f"{supabase_url}/functions/v1/gateway-token-refresh" + url = f"{supabase_url}/functions/v1/refresh-mqtt-token" headers = { "Authorization": f"Bearer {supabase_key}", "Content-Type": "application/json", From c76e60dd47ef6b0204fbcc14c5b03c23757086c9 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:47:09 +0000 Subject: [PATCH 10/11] fix(gateway): use oauth-refresh endpoint with provider predbat_gateway MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consistent with Fox OAuth pattern — single oauth-refresh edge function handles all providers including gateway MQTT token refresh. Co-Authored-By: Claude Opus 4.6 --- apps/predbat/gateway.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index 6a674a8b5..bf234010d 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -446,14 +446,14 @@ async def _check_token_refresh(self): self.log("Warn: GatewayMQTT: Token refresh skipped — missing env vars or instance_id") return - url = f"{supabase_url}/functions/v1/refresh-mqtt-token" + url = f"{supabase_url}/functions/v1/oauth-refresh" headers = { "Authorization": f"Bearer {supabase_key}", "Content-Type": "application/json", } payload = { "instance_id": instance_id, - "refresh_token": self.mqtt_refresh_token, + "provider": "predbat_gateway", } self.log("Info: GatewayMQTT: Refreshing MQTT token") From 05c9ebdef3482f6318a33bbbb7e5fbc935a8c7a0 Mon Sep 17 00:00:00 2001 From: Mark Gascoyne Date: Fri, 13 Mar 2026 07:52:40 +0000 Subject: [PATCH 11/11] =?UTF-8?q?refactor(gateway):=20remove=20mqtt=5Frefr?= =?UTF-8?q?esh=5Ftoken=20=E2=80=94=20server=20handles=20refresh?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit oauth-refresh edge function 
reads refresh token from instance secrets, so the component doesn't need to hold or pass it. Consistent with how Fox OAuth works via OAuthMixin. Co-Authored-By: Claude Opus 4.6 --- apps/predbat/components.py | 1 - apps/predbat/gateway.py | 14 ++++---------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/apps/predbat/components.py b/apps/predbat/components.py index a58cd82f9..4818705f0 100644 --- a/apps/predbat/components.py +++ b/apps/predbat/components.py @@ -317,7 +317,6 @@ "mqtt_host": {"required": True, "config": "gateway_mqtt_host"}, "mqtt_port": {"required": False, "config": "gateway_mqtt_port", "default": 8883}, "mqtt_token": {"required": True, "config": "gateway_mqtt_token"}, - "mqtt_refresh_token": {"required": False, "config": "gateway_mqtt_refresh_token"}, }, "phase": 1, "can_restart": True, diff --git a/apps/predbat/gateway.py b/apps/predbat/gateway.py index bf234010d..88bd73621 100644 --- a/apps/predbat/gateway.py +++ b/apps/predbat/gateway.py @@ -77,7 +77,7 @@ class GatewayMQTT(ComponentBase): Instance methods handle MQTT lifecycle and ComponentBase integration. """ - def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqtt_token=None, mqtt_refresh_token=None, **kwargs): + def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqtt_token=None, **kwargs): """Initialize gateway configuration and build MQTT topic strings. Args: @@ -85,14 +85,12 @@ def initialize(self, gateway_device_id=None, mqtt_host=None, mqtt_port=8883, mqt mqtt_host: MQTT broker hostname. mqtt_port: MQTT broker port (default 8883 for TLS). mqtt_token: JWT access token for MQTT authentication. - mqtt_refresh_token: Refresh token for token renewal. **kwargs: Additional keyword arguments (ignored). 
""" self.gateway_device_id = gateway_device_id self.mqtt_host = mqtt_host self.mqtt_port = mqtt_port self.mqtt_token = mqtt_token - self.mqtt_refresh_token = mqtt_refresh_token self.mqtt_token_expires_at = 0 # MQTT topic strings @@ -420,12 +418,10 @@ async def final(self): async def _check_token_refresh(self): """Check if the MQTT JWT token needs refreshing and refresh if needed. - Uses the Supabase edge function (same pattern as OAuthMixin) to - obtain a new access token before the current one expires. + Uses the oauth-refresh edge function (same pattern as OAuthMixin) to + obtain a new access token before the current one expires. The refresh + token is held server-side in instance secrets. """ - if not self.mqtt_refresh_token: - return - # Extract expiry from JWT if not yet known if not self.mqtt_token_expires_at and self.mqtt_token: self.mqtt_token_expires_at = self.extract_jwt_expiry(self.mqtt_token) @@ -469,8 +465,6 @@ async def _check_token_refresh(self): if data.get("success"): self.mqtt_token = data["access_token"] - if data.get("refresh_token"): - self.mqtt_refresh_token = data["refresh_token"] if data.get("expires_at"): try: if isinstance(data["expires_at"], (int, float)):