diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index 5f6141114..670ed7742 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -26,6 +26,7 @@ from rich.console import Console from rich.table import Table +from ..session.monitor.memory_tracker import MemoryProfile from ..utils import cli as cli_utils from ..utils.constants import EPName, EPNameOrAlias from ..utils.logging import configure_logging @@ -82,6 +83,7 @@ class BenchmarkConfig: skip_build: bool = True allow_unsupported_nodes: bool = False monitor: bool = False + memory: bool = True ep: EPNameOrAlias | None = None shape_config: dict | None = None @@ -129,6 +131,9 @@ class BenchmarkResult: # Hardware monitor metrics (from HWMonitor.to_dict()) hw_monitor: dict[str, Any] | None = None + # Memory profile (from MemoryTracker) + memory_profile: MemoryProfile | None = None + def to_dict(self) -> dict[str, Any]: """Convert to dictionary for JSON serialization.""" result = { @@ -169,6 +174,8 @@ def to_dict(self) -> dict[str, Any]: } if self.hw_monitor: result["hw_monitor"] = self.hw_monitor + if self.memory_profile: + result["memory"] = self.memory_profile.to_dict() return result @@ -281,6 +288,7 @@ def __init__(self, config: BenchmarkConfig) -> None: self.config = config self._model: WinMLPreTrainedModel | None = None self._inputs: dict[str, np.ndarray] | None = None + self._memory_tracker: Any = None def run(self) -> BenchmarkResult: """Execute full benchmark pipeline. 
@@ -288,11 +296,21 @@ def run(self) -> BenchmarkResult: Returns: BenchmarkResult with timing statistics """ + # Initialize memory tracker if enabled + if self.config.memory: + from ..session.monitor.memory_tracker import MemoryTracker + + self._memory_tracker = MemoryTracker() + self._memory_tracker.snapshot_baseline() + # [1] Load model logger.info("Loading model: %s", self.config.model_id) self._load_model() assert self._model is not None + if self._memory_tracker: + self._memory_tracker.snapshot_post_load() + # [2] Generate inputs logger.info("Generating benchmark inputs") self._generate_inputs() @@ -300,6 +318,10 @@ def run(self) -> BenchmarkResult: # Compile session early so model.device is resolved for display self._model._session.compile() + if self._memory_tracker: + adapter_luid = self._resolve_adapter_luid() + self._memory_tracker.snapshot_post_compile(adapter_luid=adapter_luid) + # Print model info before benchmark starts _print_model_info( self._model.io_config, @@ -317,6 +339,10 @@ def run(self) -> BenchmarkResult: ) stats = self._run_benchmark() + if self._memory_tracker: + adapter_luid = self._resolve_adapter_luid() + self._memory_tracker.snapshot_post_inference(adapter_luid=adapter_luid) + # [4] Collect results logger.info("Collecting results") return self._collect_results(stats) @@ -384,6 +410,40 @@ def _generate_inputs(self) -> None: batch_size=self.config.batch_size, ) + def _resolve_adapter_luid(self) -> str | None: + """Resolve the adapter LUID for device memory queries. + + Uses the same resolution logic as HWMonitor: device kind + EP name. + Returns None on non-Windows or when no adapter is available. 
+ """ + import sys + + if sys.platform != "win32": + return None + + assert self._model is not None + device = self._model.device or self.config.device + ep_name = self._model.ep_name + + if device == "cpu": + return None + + try: + from ..sysinfo.pdh_adapters import resolve_adapter_luid + + if device == "npu": + return resolve_adapter_luid("npu", ep_name=ep_name) + if device == "gpu": + return resolve_adapter_luid("gpu", ep_name=ep_name) + # "auto" — try NPU first, then GPU + luid = resolve_adapter_luid("npu", ep_name=ep_name) + if luid: + return luid + return resolve_adapter_luid("gpu", ep_name=ep_name) + except Exception: + logger.debug("Could not resolve adapter LUID for memory query", exc_info=True) + return None + def _run_benchmark(self) -> PerfStats: """Execute benchmark iterations with timing.""" if self.config.monitor: @@ -517,6 +577,8 @@ def _collect_results(self, stats: PerfStats) -> BenchmarkResult: actual_ep=self._model.ep_name, # Hardware monitor metrics (only present when --monitor is used) hw_monitor=getattr(self, "_hw_metrics", None), + # Memory profile (only present when --memory is used) + memory_profile=(self._memory_tracker.profile() if self._memory_tracker else None), ) @@ -874,6 +936,15 @@ def display_console_report(result: BenchmarkResult, console: Console) -> None: f" CPU: {cpu.get('mean_pct', 0):.1f}% avg | Mem: {ram.get('used_mb', 0):.0f} MB" ) + # Memory section (only when --memory is enabled) + if result.memory_profile: + mem = result.memory_profile + inference_ws = mem.post_inference.working_set_mb + inference_dev = mem.post_inference.device_local_mb + dev_str = f" | {inference_dev:.1f} MB (device)" if inference_dev > 0 else "" + console.print() + console.print(f"[bold]Memory:[/bold] {inference_ws:.1f} MB (process){dev_str}") + console.print() @@ -1103,6 +1174,12 @@ def _run_simple_loop( show_default=True, help="Show live hardware utilization chart for the benchmarked device (NPU, GPU, or CPU)", ) +@click.option( + 
"--memory/--no-memory", + default=True, + show_default=True, + help="Measure process and device memory at each benchmark phase", +) @click.option( "--op-tracing", "op_tracing", @@ -1134,6 +1211,7 @@ def perf( allow_unsupported_nodes: bool, module_class: str | None, monitor: bool, + memory: bool, op_tracing: str | None, output_format: cli_utils.OutputFormat, verbose: int, @@ -1272,6 +1350,7 @@ def perf( skip_build=skip_build, allow_unsupported_nodes=allow_unsupported_nodes, monitor=monitor, + memory=memory, ep=ep, shape_config=shape_config, ) diff --git a/src/winml/modelkit/session/monitor/memory_tracker.py b/src/winml/modelkit/session/monitor/memory_tracker.py new file mode 100644 index 000000000..24e33198e --- /dev/null +++ b/src/winml/modelkit/session/monitor/memory_tracker.py @@ -0,0 +1,357 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Process memory tracking for perf benchmarking. + +Provides lightweight, zero-dependency process memory snapshots via Windows +``GetProcessMemoryInfo`` (ctypes). Used by ``winml perf --memory`` to measure +memory consumption at each benchmark phase. + +For device (NPU/GPU) memory, a single-shot PDH query is used to read +``\GPU Process Memory\Local Usage`` and ``\GPU Process Memory\Shared Usage``. 
+""" + +from __future__ import annotations + +import ctypes +import logging +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, ClassVar + + +if sys.platform == "win32": + import ctypes.wintypes as wintypes + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Process Memory via GetProcessMemoryInfo (Windows) +# ============================================================================= + +_MB = 1024 * 1024 + + +if sys.platform == "win32": + + class _ProcessMemoryCountersEx(ctypes.Structure): + """PROCESS_MEMORY_COUNTERS_EX structure from psapi.h.""" + + _fields_: ClassVar = [ + ("cb", wintypes.DWORD), + ("PageFaultCount", wintypes.DWORD), + ("PeakWorkingSetSize", ctypes.c_size_t), + ("WorkingSetSize", ctypes.c_size_t), + ("QuotaPeakPagedPoolUsage", ctypes.c_size_t), + ("QuotaPagedPoolUsage", ctypes.c_size_t), + ("QuotaPeakNonPagedPoolUsage", ctypes.c_size_t), + ("QuotaNonPagedPoolUsage", ctypes.c_size_t), + ("PagefileUsage", ctypes.c_size_t), + ("PeakPagefileUsage", ctypes.c_size_t), + ("PrivateUsage", ctypes.c_size_t), + ] + + +def _get_process_memory() -> tuple[float, float, float, float]: + """Get current process memory via K32GetProcessMemoryInfo. + + Uses kernel32.K32GetProcessMemoryInfo (Windows 7+) which supports + PROCESS_MEMORY_COUNTERS_EX natively. 
+ + Returns: + (working_set_mb, peak_working_set_mb, private_bytes_mb, peak_private_bytes_mb) + """ + if sys.platform != "win32": + return _get_process_memory_linux() + + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + kernel32.GetCurrentProcess.restype = wintypes.HANDLE + kernel32.K32GetProcessMemoryInfo.restype = wintypes.BOOL + kernel32.K32GetProcessMemoryInfo.argtypes = [ + wintypes.HANDLE, + ctypes.POINTER(_ProcessMemoryCountersEx), + wintypes.DWORD, + ] + + handle = kernel32.GetCurrentProcess() + counters = _ProcessMemoryCountersEx() + counters.cb = ctypes.sizeof(counters) + + success = kernel32.K32GetProcessMemoryInfo(handle, ctypes.byref(counters), counters.cb) + if not success: + err = ctypes.get_last_error() + logger.warning("K32GetProcessMemoryInfo failed (error=%d), returning zeros", err) + return (0.0, 0.0, 0.0, 0.0) + + return ( + counters.WorkingSetSize / _MB, + counters.PeakWorkingSetSize / _MB, + counters.PrivateUsage / _MB, + counters.PeakPagefileUsage / _MB, + ) + + +def _get_process_memory_linux() -> tuple[float, float, float, float]: + """Fallback for Linux: read /proc/self/status.""" + try: + with Path("/proc/self/status").open() as f: + content = f.read() + + values: dict[str, float] = {} + for line in content.splitlines(): + parts = line.split() + if len(parts) >= 2 and parts[0].rstrip(":") in ( + "VmRSS", + "VmPeak", + "VmSize", + ): + values[parts[0].rstrip(":")] = float(parts[1]) / 1024 # kB -> MB + + rss = values.get("VmRSS", 0.0) + peak = values.get("VmPeak", 0.0) + return (rss, peak, rss, peak) + except OSError: + return (0.0, 0.0, 0.0, 0.0) + + +# ============================================================================= +# Device Memory via single-shot PDH query +# ============================================================================= + + +def _get_device_memory_mb(luid: str | None) -> tuple[float, float]: + """Single-shot PDH query for device memory (local, shared) in MB. + + Args: + luid: Adapter LUID string. 
If None, returns (0, 0). + + Returns: + (local_mb, shared_mb) + """ + if luid is None or sys.platform != "win32": + return (0.0, 0.0) + + try: + from ._pdh import PdhQuery + + pid = os.getpid() + query = PdhQuery() + query.open() + + local_ok = query.add_counter( + "local", + rf"\GPU Process Memory(pid_{pid}_luid_{luid}_phys_0)\Local Usage", + fmt="large", + ) + shared_ok = query.add_counter( + "shared", + rf"\GPU Process Memory(pid_{pid}_luid_{luid}_phys_0)\Shared Usage", + fmt="large", + ) + + if not local_ok and not shared_ok: + query.close() + return (0.0, 0.0) + + # PDH large counters don't need priming (not rate-based) + query.prime() + values = query.collect() + query.close() + + local_bytes = values.get("local") or 0 + shared_bytes = values.get("shared") or 0 + return (local_bytes / _MB, shared_bytes / _MB) + except Exception: + logger.debug("Device memory query failed", exc_info=True) + return (0.0, 0.0) + + +# ============================================================================= +# Data Classes +# ============================================================================= + + +@dataclass +class MemorySnapshot: + """A point-in-time memory measurement.""" + + working_set_mb: float = 0.0 + peak_working_set_mb: float = 0.0 + private_bytes_mb: float = 0.0 + peak_private_bytes_mb: float = 0.0 + device_local_mb: float = 0.0 + device_shared_mb: float = 0.0 + + def to_dict(self) -> dict[str, float]: + """JSON-serializable dictionary.""" + return { + "working_set_mb": round(self.working_set_mb, 2), + "peak_working_set_mb": round(self.peak_working_set_mb, 2), + "private_bytes_mb": round(self.private_bytes_mb, 2), + "peak_private_bytes_mb": round(self.peak_private_bytes_mb, 2), + "device_local_mb": round(self.device_local_mb, 2), + "device_shared_mb": round(self.device_shared_mb, 2), + } + + +@dataclass +class MemoryProfile: + """Memory measurements across benchmark phases.""" + + baseline: MemorySnapshot + post_load: MemorySnapshot + post_compile: 
MemorySnapshot + post_inference: MemorySnapshot + + @property + def load_delta_mb(self) -> float: + """Working set increase from model loading.""" + return self.post_load.working_set_mb - self.baseline.working_set_mb + + @property + def compile_delta_mb(self) -> float: + """Working set increase from session compilation.""" + return self.post_compile.working_set_mb - self.post_load.working_set_mb + + @property + def inference_delta_mb(self) -> float: + """Working set increase during inference.""" + return self.post_inference.working_set_mb - self.post_compile.working_set_mb + + @property + def total_delta_mb(self) -> float: + """Total working set increase from baseline.""" + return self.post_inference.working_set_mb - self.baseline.working_set_mb + + @property + def peak_working_set_mb(self) -> float: + """Peak working set across all phases (from OS counter).""" + return self.post_inference.peak_working_set_mb + + @property + def peak_device_local_mb(self) -> float: + """Peak device local memory across all phases.""" + return max( + self.baseline.device_local_mb, + self.post_load.device_local_mb, + self.post_compile.device_local_mb, + self.post_inference.device_local_mb, + ) + + @property + def peak_device_shared_mb(self) -> float: + """Peak device shared memory across all phases.""" + return max( + self.baseline.device_shared_mb, + self.post_load.device_shared_mb, + self.post_compile.device_shared_mb, + self.post_inference.device_shared_mb, + ) + + def to_dict(self) -> dict[str, Any]: + """JSON-serializable dictionary.""" + return { + "baseline": self.baseline.to_dict(), + "post_load": self.post_load.to_dict(), + "post_compile": self.post_compile.to_dict(), + "post_inference": self.post_inference.to_dict(), + "peak_working_set_mb": round(self.peak_working_set_mb, 2), + "peak_device_local_mb": round(self.peak_device_local_mb, 2), + "peak_device_shared_mb": round(self.peak_device_shared_mb, 2), + "total_delta_working_set_mb": round(self.total_delta_mb, 2), + } + + +# 
# =============================================================================
# MemoryTracker
# =============================================================================


class MemoryTracker:
    """Collects one memory snapshot per benchmark phase.

    Call the four ``snapshot_*`` methods in order, then :meth:`profile`::

        tracker = MemoryTracker()
        tracker.snapshot_baseline()
        # ... load model ...
        tracker.snapshot_post_load()
        # ... compile ...
        tracker.snapshot_post_compile(adapter_luid="0x...")
        # ... run benchmark ...
        tracker.snapshot_post_inference(adapter_luid="0x...")
        profile = tracker.profile()
    """

    def __init__(self) -> None:
        # Each slot stays None until its phase has been captured.
        self._baseline: MemorySnapshot | None = None
        self._post_load: MemorySnapshot | None = None
        self._post_compile: MemorySnapshot | None = None
        self._post_inference: MemorySnapshot | None = None

    def _take_snapshot(self, adapter_luid: str | None = None) -> MemorySnapshot:
        """Measure process memory and, for the given adapter, device memory."""
        process_mem = _get_process_memory()
        device_mem = _get_device_memory_mb(adapter_luid)
        return MemorySnapshot(
            working_set_mb=process_mem[0],
            peak_working_set_mb=process_mem[1],
            private_bytes_mb=process_mem[2],
            peak_private_bytes_mb=process_mem[3],
            device_local_mb=device_mem[0],
            device_shared_mb=device_mem[1],
        )

    def snapshot_baseline(self) -> None:
        """Record memory before the model is loaded (process memory only)."""
        self._baseline = self._take_snapshot()

    def snapshot_post_load(self) -> None:
        """Record memory after the model is loaded (process memory only)."""
        self._post_load = self._take_snapshot()

    def snapshot_post_compile(self, adapter_luid: str | None = None) -> None:
        """Record memory after session compilation.

        Args:
            adapter_luid: Adapter LUID for the device memory query, or None
                to record process memory only.
        """
        self._post_compile = self._take_snapshot(adapter_luid)

    def snapshot_post_inference(self, adapter_luid: str | None = None) -> None:
        """Record memory after the benchmark has finished.

        Args:
            adapter_luid: Adapter LUID for the device memory query, or None
                to record process memory only.
        """
        self._post_inference = self._take_snapshot(adapter_luid)

    def profile(self) -> MemoryProfile | None:
        """Assemble the recorded snapshots into a MemoryProfile.

        Returns:
            The profile, or None (after logging a warning) when at least one
            phase was never captured.
        """
        baseline = self._baseline
        loaded = self._post_load
        compiled = self._post_compile
        inferred = self._post_inference

        if baseline is None or loaded is None or compiled is None or inferred is None:
            logger.warning("Incomplete memory snapshots, cannot build profile")
            return None

        return MemoryProfile(
            baseline=baseline,
            post_load=loaded,
            post_compile=compiled,
            post_inference=inferred,
        )
"""Tests for the memory_tracker module."""

from __future__ import annotations

import sys

import pytest

from winml.modelkit.session.monitor.memory_tracker import (
    MemoryProfile,
    MemorySnapshot,
    MemoryTracker,
    _get_process_memory,
)


# _get_process_memory returns zeros where neither the Windows API nor
# /proc/self/status is available, so "memory > 0" only holds on these platforms.
requires_process_counters = pytest.mark.skipif(
    not (sys.platform == "win32" or sys.platform.startswith("linux")),
    reason="process memory counters are only implemented for Windows and Linux",
)


class TestGetProcessMemory:
    """Test the process memory retrieval function."""

    def test_returns_four_floats(self) -> None:
        result = _get_process_memory()
        assert len(result) == 4
        for val in result:
            assert isinstance(val, float)

    @requires_process_counters
    def test_working_set_positive(self) -> None:
        ws, peak_ws, priv, peak_priv = _get_process_memory()
        # Our process should be using *some* memory
        assert ws > 0
        assert peak_ws >= ws
        assert priv > 0
        assert peak_priv >= priv


class TestMemorySnapshot:
    """Test MemorySnapshot dataclass."""

    def test_to_dict(self) -> None:
        snap = MemorySnapshot(
            working_set_mb=100.123,
            peak_working_set_mb=120.456,
            private_bytes_mb=80.789,
            peak_private_bytes_mb=90.012,
            device_local_mb=50.347,
            device_shared_mb=10.678,
        )
        d = snap.to_dict()
        assert d["working_set_mb"] == 100.12
        assert d["peak_working_set_mb"] == 120.46
        assert d["private_bytes_mb"] == 80.79
        assert d["peak_private_bytes_mb"] == 90.01
        assert d["device_local_mb"] == 50.35
        assert d["device_shared_mb"] == 10.68

    def test_defaults_are_zero(self) -> None:
        snap = MemorySnapshot()
        assert snap.working_set_mb == 0.0
        assert snap.device_local_mb == 0.0


class TestMemoryProfile:
    """Test MemoryProfile computed properties."""

    @pytest.fixture
    def profile(self) -> MemoryProfile:
        return MemoryProfile(
            baseline=MemorySnapshot(
                working_set_mb=100.0,
                peak_working_set_mb=100.0,
                private_bytes_mb=120.0,
                peak_private_bytes_mb=120.0,
            ),
            post_load=MemorySnapshot(
                working_set_mb=300.0,
                peak_working_set_mb=310.0,
                private_bytes_mb=350.0,
                peak_private_bytes_mb=350.0,
            ),
            post_compile=MemorySnapshot(
                working_set_mb=320.0,
                peak_working_set_mb=325.0,
                private_bytes_mb=370.0,
                peak_private_bytes_mb=375.0,
                device_local_mb=50.0,
            ),
            post_inference=MemorySnapshot(
                working_set_mb=330.0,
                peak_working_set_mb=340.0,
                private_bytes_mb=380.0,
                peak_private_bytes_mb=385.0,
                device_local_mb=52.0,
                device_shared_mb=8.0,
            ),
        )

    def test_load_delta(self, profile: MemoryProfile) -> None:
        assert profile.load_delta_mb == pytest.approx(200.0)

    def test_compile_delta(self, profile: MemoryProfile) -> None:
        assert profile.compile_delta_mb == pytest.approx(20.0)

    def test_inference_delta(self, profile: MemoryProfile) -> None:
        assert profile.inference_delta_mb == pytest.approx(10.0)

    def test_total_delta(self, profile: MemoryProfile) -> None:
        assert profile.total_delta_mb == pytest.approx(230.0)

    def test_peak_working_set(self, profile: MemoryProfile) -> None:
        assert profile.peak_working_set_mb == pytest.approx(340.0)

    def test_peak_device_local(self, profile: MemoryProfile) -> None:
        assert profile.peak_device_local_mb == pytest.approx(52.0)

    def test_peak_device_shared(self, profile: MemoryProfile) -> None:
        assert profile.peak_device_shared_mb == pytest.approx(8.0)

    def test_to_dict(self, profile: MemoryProfile) -> None:
        d = profile.to_dict()
        assert "baseline" in d
        assert "post_load" in d
        assert "post_compile" in d
        assert "post_inference" in d
        assert d["peak_working_set_mb"] == 340.0
        assert d["peak_device_local_mb"] == 52.0
        assert d["peak_device_shared_mb"] == 8.0
        assert d["total_delta_working_set_mb"] == 230.0


class TestMemoryTracker:
    """Test MemoryTracker snapshot collection."""

    @requires_process_counters
    def test_full_workflow(self) -> None:
        tracker = MemoryTracker()
        tracker.snapshot_baseline()
        tracker.snapshot_post_load()
        tracker.snapshot_post_compile()
        tracker.snapshot_post_inference()
        profile = tracker.profile()

        assert profile is not None
        assert profile.baseline.working_set_mb > 0
        assert profile.post_inference.working_set_mb > 0

    def test_incomplete_returns_none(self) -> None:
        tracker = MemoryTracker()
        tracker.snapshot_baseline()
        # Missing other phases
        profile = tracker.profile()
        assert profile is None

    def test_snapshots_are_nondecreasing(self) -> None:
        """Working set after an allocation should not be below the baseline."""
        tracker = MemoryTracker()
        tracker.snapshot_baseline()

        # Allocate something to ensure memory grows
        _data = [bytearray(1024 * 1024) for _ in range(5)]  # ~5 MB

        tracker.snapshot_post_load()
        tracker.snapshot_post_compile()
        tracker.snapshot_post_inference()
        profile = tracker.profile()

        assert profile is not None
        # post_load should be >= baseline (we allocated memory)
        assert profile.post_load.working_set_mb >= profile.baseline.working_set_mb
        # Keep _data alive until assertions complete so memory isn't reclaimed early
        assert _data is not None