diff --git a/qubes/qmemman/client.py b/qubes/qmemman/client.py index 8e0e28bac..c8430eee8 100644 --- a/qubes/qmemman/client.py +++ b/qubes/qmemman/client.py @@ -19,23 +19,31 @@ import socket import fcntl -from typing import Optional class QMemmanClient: def __init__(self) -> None: - self.sock: Optional[socket.socket] = None + self.sock: socket.socket | None = None - def request_mem(self, amount) -> bool: + def _send(self, data: str) -> bool: self.sock = socket.socket(socket.AF_UNIX) flags = fcntl.fcntl(self.sock.fileno(), fcntl.F_GETFD) flags |= fcntl.FD_CLOEXEC fcntl.fcntl(self.sock.fileno(), fcntl.F_SETFD, flags) self.sock.connect("/var/run/qubes/qmemman.sock") - self.sock.send(str(int(amount)).encode("ascii") + b"\n") + self.sock.send(data.encode("ascii")) received = self.sock.recv(1024).strip() return bool(received == b"OK") + def request_mem(self, amount: int | float) -> bool: + return self._send("{}\n".format(int(amount))) + + def set_mem(self, dom_memset: dict[int | str, int | float]) -> bool: + dom_memset_str = " ".join( + "{}:{}".format(key, value) for key, value in dom_memset.items() + ) + return self._send("{}\n".format(dom_memset_str)) + def close(self) -> None: assert isinstance(self.sock, socket.socket) self.sock.close() diff --git a/qubes/qmemman/systemstate.py b/qubes/qmemman/systemstate.py index d85e17e1f..60acd8065 100644 --- a/qubes/qmemman/systemstate.py +++ b/qubes/qmemman/systemstate.py @@ -40,6 +40,11 @@ CHECK_MB_S = 100 MIN_TOTAL_MEMORY_TRANSFER = 150 * 1024 * 1024 MIN_MEM_CHANGE_WHEN_UNDER_PREF = 15 * 1024 * 1024 +#: number of loop iterations for CHECK_PERIOD_S seconds +CHECK_PERIOD = max(1, int((CHECK_PERIOD_S + 0.0) / BALLOON_DELAY)) +#: number of free memory bytes expected to get during CHECK_PERIOD_S +#: seconds +CHECK_DELTA = CHECK_PERIOD_S * CHECK_MB_S * 1024 * 1024 class SystemState: @@ -110,11 +115,13 @@ def get_free_xen_mem(self) -> int: ) return xen_free - assigned_but_unused - # Refresh information on memory assigned to all domains - def refresh_mem_actual(self) -> None: + # Refresh information on memory assigned to all or specific domains + def refresh_mem_actual(self, domid_list: Optional[list] = None) -> None: for domain in self.xc.domain_getinfo(): domid = str(domain["domid"]) if domid in self.dom_dict: + if domid_list and domid not in domid_list: + continue dom = self.dom_dict[domid] # Real memory usage dom.mem_current = domain["mem_kb"] * 1024 @@ -216,17 +223,12 @@ def do_balloon(self, mem_size) -> bool: for dom in self.dom_dict.values(): dom.no_progress = False - #: number of loop iterations for CHECK_PERIOD_S seconds - check_period = max(1, int((CHECK_PERIOD_S + 0.0) / BALLOON_DELAY)) - #: number of free memory bytes expected to get during CHECK_PERIOD_S - #: seconds - check_delta = CHECK_PERIOD_S * CHECK_MB_S * 1024 * 1024 #: helper array for holding free memory size, CHECK_PERIOD_S seconds #: ago, at every loop iteration - xenfree_ring = [0] * check_period + xenfree_ring = [0] * CHECK_PERIOD while True: - self.log.debug("niter={:2d}".format(niter)) + self.log.debug("niter={:d}".format(niter)) self.refresh_mem_actual() xenfree = self.get_free_xen_mem() self.log.info("xenfree={!r}".format(xenfree)) @@ -235,10 +237,10 @@ def do_balloon(self, mem_size) -> bool: return True # fail the request if over past CHECK_PERIOD_S seconds, # we got less than CHECK_MB_S MB/s on average - ring_slot = niter % check_period + ring_slot = niter % CHECK_PERIOD if ( - niter >= check_period - and xenfree < xenfree_ring[ring_slot] + check_delta + niter >= CHECK_PERIOD + and xenfree < xenfree_ring[ring_slot] + CHECK_DELTA ): return False xenfree_ring[ring_slot] = xenfree @@ -265,6 +267,56 @@ def do_balloon(self, mem_size) -> bool: time.sleep(BALLOON_DELAY) niter = niter + 1 + def do_balloon_dom(self, dom_memset: dict) -> bool: + self.log.info("do_balloon_dom(dom_memset={!r})".format(dom_memset)) + niter = 0 + if not dom_memset: + return False + + domid_list = list(dom_memset.keys()) + dom_dict = { + domid: state + for domid, state in self.dom_dict.items() + if domid in domid_list + } + + for _, dom in dom_dict.items(): + dom.no_progress = False + + memset_reqs = {} + for domid, memset in dom_memset.items(): + if memset == 0: + mem_pref = qubes.qmemman.algo.pref_mem(dom_dict[domid]) + memset_reqs[domid] = mem_pref + self.log.debug( + "mem for dom '%s' is 0, using its pref '%s'", + domid, + mem_pref, + ) + else: + memset_reqs[domid] = memset + + succeeded = [] + while True: + self.log.debug("niter={:d}".format(niter)) + self.refresh_mem_actual(domid_list) + for domid, dom in dom_dict.items(): + assert isinstance(dom.mem_actual, int) + if ( + domid not in succeeded + and dom.mem_actual / memset_reqs[domid] < 1.1 + ): + succeeded.append(domid) + if all(dom in succeeded for dom in domid_list): + return True + if niter >= CHECK_PERIOD: + return False + for domid, memset in memset_reqs.items(): + self.mem_set(domid, memset) + self.log.debug("sleeping for {} s".format(BALLOON_DELAY)) + time.sleep(BALLOON_DELAY) + niter += 1 + def refresh_meminfo(self, domid, untrusted_meminfo_key) -> None: self.log.debug( "refresh_meminfo(domid={}, untrusted_meminfo_key={!r})".format( diff --git a/qubes/tools/qmemmand.py b/qubes/tools/qmemmand.py index 0a7e505b5..89361eac5 100644 --- a/qubes/tools/qmemmand.py +++ b/qubes/tools/qmemmand.py @@ -191,6 +191,7 @@ def handle(self): # self.request is the TCP socket connected to the client while True: self.data = self.request.recv(1024).strip() + data_args = self.data.decode("ascii").split() self.log.debug("data=%r", self.data) if len(self.data) == 0: self.log.info("client disconnected, resuming membalance") @@ -210,12 +211,25 @@ def handle(self): self.log.debug("GLOBAL_LOCK acquired") got_lock = True - if self.data.isdigit() and system_state.do_balloon( - int(self.data.decode("ascii")) - ): - resp = b"OK\n" - else: - resp = b"FAIL\n" + + resp = "INVALID_ARG" + if self.data.isdigit(): + resp = "FAIL" + memory = int(data_args[0]) + if system_state.do_balloon(memory): + resp = "OK" + elif ":" in data_args[0]: + resp = "FAIL" + dom_memset = { + str(key): int(value) + for key, value in ( + pair.split(":") for pair in data_args + ) + } + if system_state.do_balloon_dom(dom_memset): + resp = "OK" + resp = str(resp + "\n").encode("ascii") + self.log.debug("resp={!r}".format(resp)) self.request.send(resp) except BaseException as e: diff --git a/qubes/vm/dispvm.py b/qubes/vm/dispvm.py index 3f162e72d..d7978e9bd 100644 --- a/qubes/vm/dispvm.py +++ b/qubes/vm/dispvm.py @@ -370,6 +370,25 @@ async def on_domain_started_dispvm( self.app.save() self.preload_complete.set() + @qubes.events.handler("domain-pre-paused") + async def on_domain_pre_paused( + self, event, **kwargs + ): # pylint: disable=unused-argument + if not self.is_preload or self.maxmem == 0: + return + qmemman_client = None + try: + qmemman_client = await asyncio.get_event_loop().run_in_executor( + None, self.set_mem + ) + except Exception as exc: + self.log.warning( + "Preload memory request before pause failed: %s", str(exc) + ) + if qmemman_client: + qmemman_client.close() + raise + @qubes.events.handler("domain-paused") def on_domain_paused( self, event, **kwargs diff --git a/qubes/vm/qubesvm.py b/qubes/vm/qubesvm.py index 6dc5be0ee..432111331 100644 --- a/qubes/vm/qubesvm.py +++ b/qubes/vm/qubesvm.py @@ -2076,6 +2076,25 @@ def request_mem(self, mem_required=None): return qmemman_client + def set_mem(self): + """ + Balloon to mem_pref. + """ + if not qmemman_present or self.maxmem == 0: + return None + + qmemman_client = qubes.qmemman.client.QMemmanClient() + try: + result = qmemman_client.set_mem({self.xid: 0}) + except IOError as e: + raise IOError("Failed to connect to qmemman: {!s}".format(e)) + + if not result: + qmemman_client.close() + self.log.warning("Failed to set memory") + + return qmemman_client + @staticmethod async def start_daemon(*command, input=None, **kwargs): """Start a daemon for the VM