From 6a02b68c0def107f5a4594ee3c78f7bb6e30be98 Mon Sep 17 00:00:00 2001 From: Harsh Rawat Date: Fri, 8 May 2026 01:27:08 +0530 Subject: [PATCH 1/2] [shimV2] lazy-initialize SCSI controller and introduce internal interfaces Deferred SCSI controller setup from VM creation to first use, making SCSIController context-aware and fallible. Introduced internal utilityVM and guestManager interfaces on vm.Controller to decouple from concrete vmmanager/guestmanager types and enable mocking. Propagated the new signature through the LCOW pod controller and updated/regenerated the affected mocks and tests. Signed-off-by: Harsh Rawat --- internal/controller/pod/mocks/mock_types.go | 11 +++-- internal/controller/pod/pod_lcow.go | 7 ++- internal/controller/pod/pod_lcow_test.go | 14 +++++- internal/controller/pod/types_lcow.go | 2 +- internal/controller/vm/types.go | 51 +++++++++++++++++++++ internal/controller/vm/vm.go | 20 ++++---- internal/controller/vm/vm_devices.go | 22 +++++++-- internal/controller/vm/vm_lcow.go | 19 +++++--- internal/controller/vm/vm_wcow.go | 7 +-- internal/vm/guestmanager/guest.go | 11 +++-- internal/vm/vmmanager/utils.go | 11 +---- 11 files changed, 127 insertions(+), 48 deletions(-) diff --git a/internal/controller/pod/mocks/mock_types.go b/internal/controller/pod/mocks/mock_types.go index 5059ad096f..d6d32bd26d 100644 --- a/internal/controller/pod/mocks/mock_types.go +++ b/internal/controller/pod/mocks/mock_types.go @@ -104,17 +104,18 @@ func (mr *MockvmControllerMockRecorder) RuntimeID() *gomock.Call { } // SCSIController mocks base method. -func (m *MockvmController) SCSIController() *scsi.Controller { +func (m *MockvmController) SCSIController(ctx context.Context) (*scsi.Controller, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SCSIController") + ret := m.ctrl.Call(m, "SCSIController", ctx) ret0, _ := ret[0].(*scsi.Controller) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // SCSIController indicates an expected call of SCSIController. -func (mr *MockvmControllerMockRecorder) SCSIController() *gomock.Call { +func (mr *MockvmControllerMockRecorder) SCSIController(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SCSIController", reflect.TypeOf((*MockvmController)(nil).SCSIController)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SCSIController", reflect.TypeOf((*MockvmController)(nil).SCSIController), ctx) } // VPCIController mocks base method. diff --git a/internal/controller/pod/pod_lcow.go b/internal/controller/pod/pod_lcow.go index d0b7a1840c..bf00337cf7 100644 --- a/internal/controller/pod/pod_lcow.go +++ b/internal/controller/pod/pod_lcow.go @@ -89,12 +89,17 @@ func (c *Controller) NewContainer(ctx context.Context, containerID string) (*lin return nil, fmt.Errorf("container %q already exists in pod %q", containerID, c.podID) } + scsiCtrl, err := c.vm.SCSIController(ctx) + if err != nil { + return nil, fmt.Errorf("get SCSI controller for pod %s: %w", c.podID, err) + } + containerCtrl := linuxcontainer.New( c.vm.RuntimeID(), c.gcsPodID, containerID, c.vm.Guest(), - c.vm.SCSIController(), + scsiCtrl, c.vm.Plan9Controller(), c.vm.VPCIController(), ) diff --git a/internal/controller/pod/pod_lcow_test.go b/internal/controller/pod/pod_lcow_test.go index 5e28a6ad8a..dbf3b0f20c 100644 --- a/internal/controller/pod/pod_lcow_test.go +++ b/internal/controller/pod/pod_lcow_test.go @@ -42,7 +42,7 @@ func newSetup(t *testing.T) (*mocks.MockvmController, *mocks.MocknetworkControll func expectVMCallsForNewContainer(vm *mocks.MockvmController) { vm.EXPECT().RuntimeID().Return("vm-runtime-1") vm.EXPECT().Guest().Return(nil) - vm.EXPECT().SCSIController().Return(nil) + vm.EXPECT().SCSIController(gomock.Any()).Return(nil, nil) vm.EXPECT().Plan9Controller().Return(nil) vm.EXPECT().VPCIController().Return(nil) } @@ -191,6 +191,18 @@ func TestNewContainer(t *testing.T) { } } }) + + t.Run("scsi controller error", func(t *testing.T) { + vm, _, c := newSetup(t) + vm.EXPECT().SCSIController(gomock.Any()).Return(nil, errTest) + + if _, err := c.NewContainer(t.Context(), "container-scsi-fail"); !errors.Is(err, errTest) { + t.Fatalf("NewContainer error = %v, want %v", err, errTest) + } + if _, ok := c.containers["container-scsi-fail"]; ok { + t.Error("container should not be registered when SCSIController fails") + } + }) } // TestListContainers verifies snapshots of the live container map. diff --git a/internal/controller/pod/types_lcow.go b/internal/controller/pod/types_lcow.go index a2455839ba..503ae9687b 100644 --- a/internal/controller/pod/types_lcow.go +++ b/internal/controller/pod/types_lcow.go @@ -23,7 +23,7 @@ type vmController interface { Guest() *guestmanager.Guest // SCSIController returns the SCSI device controller for the VM. - SCSIController() *scsi.Controller + SCSIController(ctx context.Context) (*scsi.Controller, error) // VPCIController returns the vPCI device controller for the VM. VPCIController() *vpci.Controller diff --git a/internal/controller/vm/types.go b/internal/controller/vm/types.go index 9312bedaae..e7c1e7cdc8 100644 --- a/internal/controller/vm/types.go +++ b/internal/controller/vm/types.go @@ -3,15 +3,32 @@ package vm import ( + "context" + "net" "time" + "github.com/Microsoft/go-winio" runhcsoptions "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" + "github.com/Microsoft/hcsshim/internal/cmd" + "github.com/Microsoft/hcsshim/internal/gcs" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" vmsandbox "github.com/Microsoft/hcsshim/sandbox-spec/vm/v2" "github.com/Microsoft/go-winio/pkg/guid" ) +// Package-level indirection for HCS/winio entry points used by [Controller], +// allowing tests to swap in fakes without standing up a real VM. +var ( + // listenHVSock opens a host-side hvsock listener. + listenHVSock = winio.ListenHvsock + // createVM creates the underlying utility VM via HCS. + createVM = vmmanager.Create +) + // CreateOptions contains the configuration needed to create a new VM. type CreateOptions struct { // ID specifies the unique identifier for the VM. @@ -49,3 +66,37 @@ type ExitStatus struct { // This will be nil if the VM exited cleanly. Err error } + +// utilityVM is the subset of the underlying [vmmanager.UtilityVM] surface that +// the [Controller] depends on. It is defined as an interface so that the +// Controller can be unit tested without starting up a real HCS VM. +type utilityVM interface { + ID() string + RuntimeID() guid.GUID + Start(ctx context.Context) error + AcceptConnection(ctx context.Context, l net.Listener, closeConnection bool) (net.Conn, error) + Wait(ctx context.Context) error + Terminate(ctx context.Context) error + Close(ctx context.Context) error + SetCPUGroup(ctx context.Context, settings *hcsschema.CpuGroup) error + UpdateCPULimits(ctx context.Context, settings *hcsschema.ProcessorLimits) error + UpdateMemory(ctx context.Context, memory uint64) error + PropertiesV2(ctx context.Context, types ...hcsschema.PropertyType) (*hcsschema.Properties, error) + StartedTime() time.Time + StoppedTime() time.Time + ExitError() error +} + +// guestManager is the subset of [guestmanager.Guest] that the [Controller] +// depends on. It is defined as an interface so that the Controller can be +// unit tested without a live guest connection. +type guestManager interface { + PrepareConnection(GCSServiceID guid.GUID) error + CreateConnection(ctx context.Context, opts ...guestmanager.ConfigOption) error + CloseConnection() error + AddSecurityPolicy(ctx context.Context, opts guestresource.ConfidentialOptions) error + InjectPolicyFragment(ctx context.Context, fragment guestresource.SecurityPolicyFragment) error + Capabilities() gcs.GuestDefinedCapabilities + DumpStacks(ctx context.Context) (string, error) + ExecIntoUVM(ctx context.Context, request *cmd.CmdProcessRequest) (int, error) +} diff --git a/internal/controller/vm/vm.go b/internal/controller/vm/vm.go index d623ef2d7f..875da4be73 100644 --- a/internal/controller/vm/vm.go +++ b/internal/controller/vm/vm.go @@ -21,7 +21,6 @@ import ( "github.com/Microsoft/hcsshim/internal/shimdiag" "github.com/Microsoft/hcsshim/internal/timeout" "github.com/Microsoft/hcsshim/internal/vm/guestmanager" - "github.com/Microsoft/hcsshim/internal/vm/vmmanager" "github.com/Microsoft/hcsshim/internal/vm/vmutils" iwin "github.com/Microsoft/hcsshim/internal/windows" @@ -36,8 +35,8 @@ import ( // and its associated resources. type Controller struct { vmID string - uvm *vmmanager.UtilityVM - guest *guestmanager.Guest + uvm utilityVM + guest guestManager // vmState tracks the current state of the VM lifecycle. // Access must be guarded by mu. @@ -46,6 +45,9 @@ type Controller struct { // mu guards the concurrent access to the Controller's fields and operations. mu sync.RWMutex + // hcsDocument is the HCS compute system document used to create the VM. + hcsDocument *hcsschema.ComputeSystem + // logOutputDone is closed when the GCS log output processing goroutine completes. logOutputDone chan struct{} @@ -80,7 +82,7 @@ func New() *Controller { // Guest returns the guest manager instance for this VM. // The guest manager provides access to guest-host communication. func (c *Controller) Guest() *guestmanager.Guest { - return c.guest + return c.guest.(*guestmanager.Guest) } // State returns the current VM state. @@ -122,7 +124,7 @@ func (c *Controller) CreateVM(ctx context.Context, opts *CreateOptions) error { } // Create the VM via vmmanager. - uvm, err := vmmanager.Create(ctx, opts.ID, hcsDocument) + uvm, err := createVM(ctx, opts.ID, hcsDocument) if err != nil { return fmt.Errorf("failed to create VM: %w", err) } @@ -130,18 +132,12 @@ func (c *Controller) CreateVM(ctx context.Context, opts *CreateOptions) error { // Set the Controller parameters after successful creation. c.vmID = opts.ID c.uvm = uvm + c.hcsDocument = hcsDocument // Initialize the GuestManager for managing guest interactions. // We will create the guest connection via GuestManager during StartVM. c.guest = guestmanager.New(ctx, uvm) - // Eager initialize the SCSI controller as opposed to all other controllers. - // This is because we always use SCSI for attaching scratch VHDs. - c.scsiController, err = newSCSIController(ctx, hcsDocument, c.uvm, c.guest) - if err != nil { - return fmt.Errorf("failed to initialize SCSI controller: %w", err) - } - c.vmState = StateCreated return nil } diff --git a/internal/controller/vm/vm_devices.go b/internal/controller/vm/vm_devices.go index e85044a956..ff5e6a1463 100644 --- a/internal/controller/vm/vm_devices.go +++ b/internal/controller/vm/vm_devices.go @@ -11,11 +11,25 @@ import ( "github.com/Microsoft/hcsshim/internal/controller/device/vpci" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" ) // SCSIController returns the singleton SCSI device controller for this VM. -func (c *Controller) SCSIController() *scsi.Controller { - return c.scsiController +func (c *Controller) SCSIController(ctx context.Context) (*scsi.Controller, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.scsiController == nil { + uvm := c.uvm.(*vmmanager.UtilityVM) + guest := c.guest.(*guestmanager.Guest) + ctrl, err := newSCSIController(ctx, c.hcsDocument, uvm, guest) + if err != nil { + return nil, fmt.Errorf("failed to initialize SCSI controller: %w", err) + } + c.scsiController = ctrl + } + return c.scsiController, nil } // VPCIController returns the singleton vPCI device controller for this VM. @@ -24,7 +38,9 @@ func (c *Controller) VPCIController() *vpci.Controller { defer c.mu.Unlock() if c.vpciController == nil { - c.vpciController = vpci.New(c.uvm, c.guest) + uvm := c.uvm.(*vmmanager.UtilityVM) + guest := c.guest.(*guestmanager.Guest) + c.vpciController = vpci.New(uvm, guest) } return c.vpciController diff --git a/internal/controller/vm/vm_lcow.go b/internal/controller/vm/vm_lcow.go index 0775ff1377..40aac30edd 100644 --- a/internal/controller/vm/vm_lcow.go +++ b/internal/controller/vm/vm_lcow.go @@ -13,6 +13,7 @@ import ( "github.com/Microsoft/hcsshim/internal/controller/network" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" "github.com/Microsoft/hcsshim/internal/vm/vmmanager" "github.com/Microsoft/hcsshim/internal/vm/vmutils" @@ -38,7 +39,7 @@ func (c *Controller) SandboxOptions() *lcow.SandboxOptions { return c.sandboxOptions } -// buildConfig builds the HCS document for an LCOW VM by calling lcow.BuildSandboxConfig. +// buildHCSConfig builds the HCS document for an LCOW VM by calling lcow.BuildSandboxConfig. // It also stores the sandbox options within the controller. func (c *Controller) buildHCSConfig(ctx context.Context, opts *CreateOptions) (*hcsschema.ComputeSystem, error) { hcsDocument, sandboxOptions, err := lcow.BuildSandboxConfig(ctx, opts.Owner, opts.BundlePath, opts.ShimOpts, opts.SandboxSpec) @@ -82,10 +83,12 @@ func (c *Controller) NetworkController(networkNamespaceID string) *network.Contr policyBasedRouting = c.sandboxOptions.PolicyBasedRouting } + uvm := c.uvm.(*vmmanager.UtilityVM) + guest := c.guest.(*guestmanager.Guest) return network.New(&network.Options{ NetworkNamespace: networkNamespaceID, PolicyBasedRouting: policyBasedRouting, - }, c.uvm, c.guest, c.guest) + }, uvm, guest, guest) } // Plan9Controller returns the singleton controller which can be used @@ -100,7 +103,9 @@ func (c *Controller) Plan9Controller() *plan9.Controller { noWritableShares = c.sandboxOptions.NoWritableFileShares } - c.plan9Controller = plan9.New(c.uvm, c.guest, noWritableShares) + uvm := c.uvm.(*vmmanager.UtilityVM) + guest := c.guest.(*guestmanager.Guest) + c.plan9Controller = plan9.New(uvm, guest, noWritableShares) } return c.plan9Controller @@ -113,7 +118,7 @@ func (c *Controller) Plan9Controller() *plan9.Controller { // random data to the Linux init process when it connects. func (c *Controller) setupEntropyListener(ctx context.Context, group *errgroup.Group) error { // The Linux guest will connect to this port during init to receive entropy. - entropyConn, err := winio.ListenHvsock(&winio.HvsockAddr{ + entropyConn, err := listenHVSock(&winio.HvsockAddr{ VMID: c.uvm.RuntimeID(), ServiceID: winio.VsockServiceID(vmutils.LinuxEntropyVsockPort), }) @@ -128,7 +133,7 @@ func (c *Controller) setupEntropyListener(ctx context.Context, group *errgroup.G // must be done in a goroutine since, when using the internal bridge, the // call to Start() will block until the GCS launches, and this cannot occur // until the host accepts and closes the entropy connection. - conn, err := vmmanager.AcceptConnection(ctx, c.uvm, entropyConn, true) + conn, err := c.uvm.AcceptConnection(ctx, entropyConn, true) if err != nil { return fmt.Errorf("failed to accept connection on hvSocket for entropy: %w", err) } @@ -154,7 +159,7 @@ func (c *Controller) setupEntropyListener(ctx context.Context, group *errgroup.G // forwarded to the host's logging system for monitoring and debugging. func (c *Controller) setupLoggingListener(ctx context.Context, group *errgroup.Group) error { // The GCS will connect to this port to stream log output. - logConn, err := winio.ListenHvsock(&winio.HvsockAddr{ + logConn, err := listenHVSock(&winio.HvsockAddr{ VMID: c.uvm.RuntimeID(), ServiceID: winio.VsockServiceID(vmutils.LinuxLogVsockPort), }) @@ -167,7 +172,7 @@ func (c *Controller) setupLoggingListener(ctx context.Context, group *errgroup.G defer logConn.Close() // Accept the connection from the GCS. - conn, err := vmmanager.AcceptConnection(ctx, c.uvm, logConn, true) + conn, err := c.uvm.AcceptConnection(ctx, logConn, true) if err != nil { close(c.logOutputDone) return fmt.Errorf("failed to accept connection on hvSocket for logs: %w", err) diff --git a/internal/controller/vm/vm_wcow.go b/internal/controller/vm/vm_wcow.go index d9a57b4b61..42303b82a0 100644 --- a/internal/controller/vm/vm_wcow.go +++ b/internal/controller/vm/vm_wcow.go @@ -11,7 +11,7 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs/prot" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" - "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" "github.com/Microsoft/hcsshim/internal/vm/vmutils" "github.com/sirupsen/logrus" @@ -78,7 +78,7 @@ func (c *Controller) setupLoggingListener(ctx context.Context, _ *errgroup.Group for { // Accept a connection from the GCS. - conn, err := vmmanager.AcceptConnection(context.WithoutCancel(ctx), c.uvm, limitedListener, false) + conn, err := c.uvm.AcceptConnection(context.WithoutCancel(ctx), limitedListener, false) if err != nil { logrus.WithError(err).Error("failed to connect to log socket") break @@ -123,7 +123,8 @@ func (c *Controller) finalizeGCSConnection(ctx context.Context) error { // Update the guest manager with the HvSocket address configuration. // This enables the GCS to establish proper bidirectional communication. - err := c.guest.UpdateHvSocketAddress(ctx, hvsocketAddress) + guest := c.guest.(*guestmanager.Guest) + err := guest.UpdateHvSocketAddress(ctx, hvsocketAddress) if err != nil { return fmt.Errorf("failed to create GCS connection: %w", err) } diff --git a/internal/vm/guestmanager/guest.go b/internal/vm/guestmanager/guest.go index a95e5b9265..d3d2e328ef 100644 --- a/internal/vm/guestmanager/guest.go +++ b/internal/vm/guestmanager/guest.go @@ -8,13 +8,11 @@ import ( "net" "sync" + "github.com/Microsoft/go-winio" + "github.com/Microsoft/go-winio/pkg/guid" "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" - "github.com/Microsoft/hcsshim/internal/vm/vmmanager" - - "github.com/Microsoft/go-winio" - "github.com/Microsoft/go-winio/pkg/guid" "github.com/sirupsen/logrus" ) @@ -25,6 +23,9 @@ type uvm interface { ID() string // RuntimeID returns the Hyper-V VM GUID. RuntimeID() guid.GUID + // AcceptConnection accepts a connection on l, aborting on ctx.Done() or + // VM exit. + AcceptConnection(ctx context.Context, l net.Listener, closeConnection bool) (net.Conn, error) // Wait blocks until the VM exits or ctx is cancelled. Wait(ctx context.Context) error // ExitError returns the error that caused the VM to exit, if any. @@ -116,7 +117,7 @@ func (gm *Guest) CreateConnection(ctx context.Context, opts ...ConfigOption) err l := gm.gcListener gm.gcListener = nil - conn, err := vmmanager.AcceptConnection(ctx, gm.uvm, l, true) + conn, err := gm.uvm.AcceptConnection(ctx, l, true) if err != nil { return fmt.Errorf("failed to connect to GCS: %w", err) } diff --git a/internal/vm/vmmanager/utils.go b/internal/vm/vmmanager/utils.go index f6ac5a04f1..7e1194fb97 100644 --- a/internal/vm/vmmanager/utils.go +++ b/internal/vm/vmmanager/utils.go @@ -7,18 +7,9 @@ import ( "net" ) -// vmWaiter exposes the subset of VM lifecycle needed by [AcceptConnection]: -// Implemented by [UtilityVM]. -type vmWaiter interface { - // Wait blocks until the VM exits or ctx is cancelled. - Wait(ctx context.Context) error - // ExitError returns the error that caused the VM to exit, if any. - ExitError() error -} - // AcceptConnection accepts a connection and then closes a listener. // It monitors ctx.Done() and uvm.Wait() for early termination. -func AcceptConnection(ctx context.Context, uvm vmWaiter, l net.Listener, closeConnection bool) (net.Conn, error) { +func (uvm *UtilityVM) AcceptConnection(ctx context.Context, l net.Listener, closeConnection bool) (net.Conn, error) { // Channel to capture the accept result type acceptResult struct { conn net.Conn From 8b78ebf11f0a93e895d4baf531f6c51eaff2c165 Mon Sep 17 00:00:00 2001 From: Shreyansh Sancheti Date: Wed, 13 May 2026 16:11:18 +0530 Subject: [PATCH 2/2] controller/vm: add function injection seams and unit tests Extends rawahars's interface refactor (utilityVM, guestManager, listenHVSock, createVM) with five additional package-level function variables for test injection: newGuestManager, lookupVMMEM, getProcessMemoryInfo (shared), buildSandboxConfig and parseUVMReferenceInfo (LCOW-only). Also fixes the WCOW setupLoggingListener to use the injected listenHVSock var instead of calling winio.ListenHvsock directly. Adds 62 unit tests (LCOW) / 61 (WCOW) covering: state machine guards, TerminateVM cleanup chain (Close failure to Invalid, CloseConnection error logged not returned, double-terminate idempotency), StartVM error cascade (PrepareConnection, Start, CreateConnection failures each transition to Invalid), waitForVMExit background goroutine race with TerminateVM, ExecIntoHost precondition checks and active-count tracking, DumpStacks capability branching, Wait dual-wait semantics, Stats vmmem lookup and memory reporting (VA-backed vs physically-backed), and concurrent access under the race detector. All tests run without admin or HCS. Signed-off-by: Shreyansh Sancheti --- internal/controller/vm/mocks/mock_types.go | 397 ++++++ internal/controller/vm/platform_lcow_test.go | 5 + internal/controller/vm/platform_wcow_test.go | 5 + internal/controller/vm/types.go | 14 +- internal/controller/vm/vm.go | 7 +- .../controller/vm/vm_createvm_lcow_test.go | 158 +++ internal/controller/vm/vm_lcow.go | 11 +- internal/controller/vm/vm_test.go | 1164 +++++++++++++++++ internal/controller/vm/vm_wcow.go | 2 +- 9 files changed, 1755 insertions(+), 8 deletions(-) create mode 100644 internal/controller/vm/mocks/mock_types.go create mode 100644 internal/controller/vm/platform_lcow_test.go create mode 100644 internal/controller/vm/platform_wcow_test.go create mode 100644 internal/controller/vm/vm_createvm_lcow_test.go create mode 100644 internal/controller/vm/vm_test.go diff --git a/internal/controller/vm/mocks/mock_types.go b/internal/controller/vm/mocks/mock_types.go new file mode 100644 index 0000000000..6e4a668b9d --- /dev/null +++ b/internal/controller/vm/mocks/mock_types.go @@ -0,0 +1,397 @@ +//go:build windows && (lcow || wcow) + +// Code generated by MockGen. DO NOT EDIT. +// Source: internal/controller/vm/types.go +// +// Generated by this command: +// +// mockgen -build_flags=-tags=windows,lcow -build_constraint=windows && (lcow || wcow) -source internal/controller/vm/types.go -package mocks -destination internal/controller/vm/mocks/mock_types.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + net "net" + reflect "reflect" + time "time" + + guid "github.com/Microsoft/go-winio/pkg/guid" + cmd "github.com/Microsoft/hcsshim/internal/cmd" + gcs "github.com/Microsoft/hcsshim/internal/gcs" + schema2 "github.com/Microsoft/hcsshim/internal/hcs/schema2" + guestresource "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + guestmanager "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + gomock "go.uber.org/mock/gomock" +) + +// MockutilityVM is a mock of utilityVM interface. +type MockutilityVM struct { + ctrl *gomock.Controller + recorder *MockutilityVMMockRecorder + isgomock struct{} +} + +// MockutilityVMMockRecorder is the mock recorder for MockutilityVM. +type MockutilityVMMockRecorder struct { + mock *MockutilityVM +} + +// NewMockutilityVM creates a new mock instance. +func NewMockutilityVM(ctrl *gomock.Controller) *MockutilityVM { + mock := &MockutilityVM{ctrl: ctrl} + mock.recorder = &MockutilityVMMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockutilityVM) EXPECT() *MockutilityVMMockRecorder { + return m.recorder +} + +// AcceptConnection mocks base method. +func (m *MockutilityVM) AcceptConnection(ctx context.Context, l net.Listener, closeConnection bool) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptConnection", ctx, l, closeConnection) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptConnection indicates an expected call of AcceptConnection. +func (mr *MockutilityVMMockRecorder) AcceptConnection(ctx, l, closeConnection any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptConnection", reflect.TypeOf((*MockutilityVM)(nil).AcceptConnection), ctx, l, closeConnection) +} + +// Close mocks base method. +func (m *MockutilityVM) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockutilityVMMockRecorder) Close(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockutilityVM)(nil).Close), ctx) +} + +// ExitError mocks base method. +func (m *MockutilityVM) ExitError() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExitError") + ret0, _ := ret[0].(error) + return ret0 +} + +// ExitError indicates an expected call of ExitError. +func (mr *MockutilityVMMockRecorder) ExitError() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExitError", reflect.TypeOf((*MockutilityVM)(nil).ExitError)) +} + +// ID mocks base method. +func (m *MockutilityVM) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID. +func (mr *MockutilityVMMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockutilityVM)(nil).ID)) +} + +// PropertiesV2 mocks base method. +func (m *MockutilityVM) PropertiesV2(ctx context.Context, types ...schema2.PropertyType) (*schema2.Properties, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range types { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "PropertiesV2", varargs...) + ret0, _ := ret[0].(*schema2.Properties) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PropertiesV2 indicates an expected call of PropertiesV2. +func (mr *MockutilityVMMockRecorder) PropertiesV2(ctx any, types ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, types...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PropertiesV2", reflect.TypeOf((*MockutilityVM)(nil).PropertiesV2), varargs...) +} + +// RuntimeID mocks base method. +func (m *MockutilityVM) RuntimeID() guid.GUID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RuntimeID") + ret0, _ := ret[0].(guid.GUID) + return ret0 +} + +// RuntimeID indicates an expected call of RuntimeID. +func (mr *MockutilityVMMockRecorder) RuntimeID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuntimeID", reflect.TypeOf((*MockutilityVM)(nil).RuntimeID)) +} + +// SetCPUGroup mocks base method. +func (m *MockutilityVM) SetCPUGroup(ctx context.Context, settings *schema2.CpuGroup) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetCPUGroup", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetCPUGroup indicates an expected call of SetCPUGroup. +func (mr *MockutilityVMMockRecorder) SetCPUGroup(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCPUGroup", reflect.TypeOf((*MockutilityVM)(nil).SetCPUGroup), ctx, settings) +} + +// Start mocks base method. +func (m *MockutilityVM) Start(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockutilityVMMockRecorder) Start(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockutilityVM)(nil).Start), ctx) +} + +// StartedTime mocks base method. +func (m *MockutilityVM) StartedTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartedTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StartedTime indicates an expected call of StartedTime. +func (mr *MockutilityVMMockRecorder) StartedTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedTime", reflect.TypeOf((*MockutilityVM)(nil).StartedTime)) +} + +// StoppedTime mocks base method. +func (m *MockutilityVM) StoppedTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StoppedTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StoppedTime indicates an expected call of StoppedTime. +func (mr *MockutilityVMMockRecorder) StoppedTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoppedTime", reflect.TypeOf((*MockutilityVM)(nil).StoppedTime)) +} + +// Terminate mocks base method. +func (m *MockutilityVM) Terminate(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Terminate", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Terminate indicates an expected call of Terminate. +func (mr *MockutilityVMMockRecorder) Terminate(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Terminate", reflect.TypeOf((*MockutilityVM)(nil).Terminate), ctx) +} + +// UpdateCPULimits mocks base method. +func (m *MockutilityVM) UpdateCPULimits(ctx context.Context, settings *schema2.ProcessorLimits) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateCPULimits", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateCPULimits indicates an expected call of UpdateCPULimits. +func (mr *MockutilityVMMockRecorder) UpdateCPULimits(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCPULimits", reflect.TypeOf((*MockutilityVM)(nil).UpdateCPULimits), ctx, settings) +} + +// UpdateMemory mocks base method. +func (m *MockutilityVM) UpdateMemory(ctx context.Context, memory uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateMemory", ctx, memory) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateMemory indicates an expected call of UpdateMemory. +func (mr *MockutilityVMMockRecorder) UpdateMemory(ctx, memory any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemory", reflect.TypeOf((*MockutilityVM)(nil).UpdateMemory), ctx, memory) +} + +// Wait mocks base method. +func (m *MockutilityVM) Wait(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. +func (mr *MockutilityVMMockRecorder) Wait(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockutilityVM)(nil).Wait), ctx) +} + +// MockguestManager is a mock of guestManager interface. +type MockguestManager struct { + ctrl *gomock.Controller + recorder *MockguestManagerMockRecorder + isgomock struct{} +} + +// MockguestManagerMockRecorder is the mock recorder for MockguestManager. +type MockguestManagerMockRecorder struct { + mock *MockguestManager +} + +// NewMockguestManager creates a new mock instance. +func NewMockguestManager(ctrl *gomock.Controller) *MockguestManager { + mock := &MockguestManager{ctrl: ctrl} + mock.recorder = &MockguestManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockguestManager) EXPECT() *MockguestManagerMockRecorder { + return m.recorder +} + +// AddSecurityPolicy mocks base method. +func (m *MockguestManager) AddSecurityPolicy(ctx context.Context, opts guestresource.ConfidentialOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddSecurityPolicy", ctx, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddSecurityPolicy indicates an expected call of AddSecurityPolicy. +func (mr *MockguestManagerMockRecorder) AddSecurityPolicy(ctx, opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSecurityPolicy", reflect.TypeOf((*MockguestManager)(nil).AddSecurityPolicy), ctx, opts) +} + +// Capabilities mocks base method. +func (m *MockguestManager) Capabilities() gcs.GuestDefinedCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Capabilities") + ret0, _ := ret[0].(gcs.GuestDefinedCapabilities) + return ret0 +} + +// Capabilities indicates an expected call of Capabilities. +func (mr *MockguestManagerMockRecorder) Capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Capabilities", reflect.TypeOf((*MockguestManager)(nil).Capabilities)) +} + +// CloseConnection mocks base method. +func (m *MockguestManager) CloseConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseConnection indicates an expected call of CloseConnection. +func (mr *MockguestManagerMockRecorder) CloseConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseConnection", reflect.TypeOf((*MockguestManager)(nil).CloseConnection)) +} + +// CreateConnection mocks base method. +func (m *MockguestManager) CreateConnection(ctx context.Context, opts ...guestmanager.ConfigOption) error { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateConnection", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateConnection indicates an expected call of CreateConnection. +func (mr *MockguestManagerMockRecorder) CreateConnection(ctx any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConnection", reflect.TypeOf((*MockguestManager)(nil).CreateConnection), varargs...) +} + +// DumpStacks mocks base method. +func (m *MockguestManager) DumpStacks(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DumpStacks", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DumpStacks indicates an expected call of DumpStacks. +func (mr *MockguestManagerMockRecorder) DumpStacks(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DumpStacks", reflect.TypeOf((*MockguestManager)(nil).DumpStacks), ctx) +} + +// ExecIntoUVM mocks base method. +func (m *MockguestManager) ExecIntoUVM(ctx context.Context, request *cmd.CmdProcessRequest) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecIntoUVM", ctx, request) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecIntoUVM indicates an expected call of ExecIntoUVM. +func (mr *MockguestManagerMockRecorder) ExecIntoUVM(ctx, request any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecIntoUVM", reflect.TypeOf((*MockguestManager)(nil).ExecIntoUVM), ctx, request) +} + +// InjectPolicyFragment mocks base method. +func (m *MockguestManager) InjectPolicyFragment(ctx context.Context, fragment guestresource.SecurityPolicyFragment) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InjectPolicyFragment", ctx, fragment) + ret0, _ := ret[0].(error) + return ret0 +} + +// InjectPolicyFragment indicates an expected call of InjectPolicyFragment. +func (mr *MockguestManagerMockRecorder) InjectPolicyFragment(ctx, fragment any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InjectPolicyFragment", reflect.TypeOf((*MockguestManager)(nil).InjectPolicyFragment), ctx, fragment) +} + +// PrepareConnection mocks base method. +func (m *MockguestManager) PrepareConnection(GCSServiceID guid.GUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrepareConnection", GCSServiceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// PrepareConnection indicates an expected call of PrepareConnection. +func (mr *MockguestManagerMockRecorder) PrepareConnection(GCSServiceID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareConnection", reflect.TypeOf((*MockguestManager)(nil).PrepareConnection), GCSServiceID) +} diff --git a/internal/controller/vm/platform_lcow_test.go b/internal/controller/vm/platform_lcow_test.go new file mode 100644 index 0000000000..15cca525c8 --- /dev/null +++ b/internal/controller/vm/platform_lcow_test.go @@ -0,0 +1,5 @@ +//go:build windows && lcow + +package vm + +func isLCOW() bool { return true } diff --git a/internal/controller/vm/platform_wcow_test.go b/internal/controller/vm/platform_wcow_test.go new file mode 100644 index 0000000000..3248907e40 --- /dev/null +++ b/internal/controller/vm/platform_wcow_test.go @@ -0,0 +1,5 @@ +//go:build windows && wcow + +package vm + +func isLCOW() bool { return false } diff --git a/internal/controller/vm/types.go b/internal/controller/vm/types.go index e7c1e7cdc8..34915b7795 100644 --- a/internal/controller/vm/types.go +++ b/internal/controller/vm/types.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Microsoft/go-winio" + "github.com/Microsoft/go-winio/pkg/process" runhcsoptions "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" "github.com/Microsoft/hcsshim/internal/cmd" "github.com/Microsoft/hcsshim/internal/gcs" @@ -15,6 +16,7 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/vm/guestmanager" "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + "github.com/Microsoft/hcsshim/internal/vm/vmutils" vmsandbox "github.com/Microsoft/hcsshim/sandbox-spec/vm/v2" "github.com/Microsoft/go-winio/pkg/guid" @@ -24,9 +26,19 @@ import ( // allowing tests to swap in fakes without standing up a real VM. var ( // listenHVSock opens a host-side hvsock listener. - listenHVSock = winio.ListenHvsock + // The concrete winio.ListenHvsock returns *winio.HvsockListener which + // satisfies net.Listener. We use net.Listener here so tests can inject fakes. + listenHVSock = func(addr *winio.HvsockAddr) (net.Listener, error) { + return winio.ListenHvsock(addr) + } // createVM creates the underlying utility VM via HCS. createVM = vmmanager.Create + // newGuestManager constructs the guest manager for guest-host communication. + newGuestManager = guestmanager.New + // lookupVMMEM finds the vmmem process handle for a given VM. + lookupVMMEM = vmutils.LookupVMMEM + // getProcessMemoryInfo queries memory stats for a process handle. + getProcessMemoryInfo = process.GetProcessMemoryInfo ) // CreateOptions contains the configuration needed to create a new VM. diff --git a/internal/controller/vm/vm.go b/internal/controller/vm/vm.go index 875da4be73..d69d2e0301 100644 --- a/internal/controller/vm/vm.go +++ b/internal/controller/vm/vm.go @@ -24,7 +24,6 @@ import ( "github.com/Microsoft/hcsshim/internal/vm/vmutils" iwin "github.com/Microsoft/hcsshim/internal/windows" - "github.com/Microsoft/go-winio/pkg/process" "github.com/containerd/errdefs" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -136,7 +135,7 @@ func (c *Controller) CreateVM(ctx context.Context, opts *CreateOptions) error { // Initialize the GuestManager for managing guest interactions. // We will create the guest connection via GuestManager during StartVM. - c.guest = guestmanager.New(ctx, uvm) + c.guest = newGuestManager(ctx, uvm) c.vmState = StateCreated return nil @@ -450,7 +449,7 @@ func (c *Controller) Stats(ctx context.Context) (*stats.VirtualMachineStatistics // Initialization of vmmemProcess to calculate stats properly for VA-backed UVMs. if c.vmmemProcess == 0 { - vmmemHandle, err := vmutils.LookupVMMEM(ctx, c.uvm.RuntimeID(), &iwin.WinAPI{}) + vmmemHandle, err := lookupVMMEM(ctx, c.uvm.RuntimeID(), &iwin.WinAPI{}) if err != nil { return nil, fmt.Errorf("cannot get stats: %w", err) } @@ -471,7 +470,7 @@ func (c *Controller) Stats(ctx context.Context) (*stats.VirtualMachineStatistics // working set size for a VA-backed UVM. To work around this, we instead // locate the vmmem process for the VM, and query that process's working set // instead, which will be the working set for the VM. - memCounters, err := process.GetProcessMemoryInfo(c.vmmemProcess) + memCounters, err := getProcessMemoryInfo(c.vmmemProcess) if err != nil { return nil, err } diff --git a/internal/controller/vm/vm_createvm_lcow_test.go b/internal/controller/vm/vm_createvm_lcow_test.go new file mode 100644 index 0000000000..8189141a5d --- /dev/null +++ b/internal/controller/vm/vm_createvm_lcow_test.go @@ -0,0 +1,158 @@ +//go:build windows && lcow + +package vm + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/Microsoft/hcsshim/internal/builder/vm/lcow" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/vm/vmmanager" + + runhcsoptions "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" + vmsandbox "github.com/Microsoft/hcsshim/sandbox-spec/vm/v2" + + "github.com/Microsoft/go-winio" + "go.uber.org/mock/gomock" +) + +// ─── CreateVM Error Paths (LCOW-only) ────────────────────────────────────────── + +func TestCreateVM(t *testing.T) { + ctx := context.Background() + opts := &CreateOptions{ID: "test-vm", Owner: "test-owner"} + + t.Run("buildHCSConfig_fails/state_stays_not_created", func(t *testing.T) { + c := New() + orig := buildSandboxConfig + t.Cleanup(func() { buildSandboxConfig = orig }) + buildSandboxConfig = func(_ context.Context, _ string, _ string, _ *runhcsoptions.Options, _ *vmsandbox.Spec) (*hcsschema.ComputeSystem, *lcow.SandboxOptions, error) { + return nil, nil, errors.New("bad sandbox spec") + } + + err := c.CreateVM(ctx, opts) + if err == nil { + t.Fatal("expected error when buildHCSConfig fails") + } + if c.State() != StateNotCreated { + t.Errorf("expected StateNotCreated, got %s", c.State()) + } + }) + + t.Run("createVM_fails/state_stays_not_created", func(t *testing.T) { + c := New() + orig := buildSandboxConfig + t.Cleanup(func() { buildSandboxConfig = orig }) + buildSandboxConfig = func(_ context.Context, _ string, _ string, _ *runhcsoptions.Options, _ *vmsandbox.Spec) (*hcsschema.ComputeSystem, *lcow.SandboxOptions, error) { + return &hcsschema.ComputeSystem{}, &lcow.SandboxOptions{}, nil + } + + origCreate := createVM + t.Cleanup(func() { createVM = origCreate }) + createVM = func(_ context.Context, _ string, _ *hcsschema.ComputeSystem) (*vmmanager.UtilityVM, error) { + return nil, errors.New("HCS create failed") + } + + err := c.CreateVM(ctx, opts) + if err == nil { + t.Fatal("expected error when createVM fails") + } + if c.State() != StateNotCreated { + t.Errorf("expected StateNotCreated, got %s", c.State()) + } + }) +} + +// ─── StartVM: AddSecurityPolicy Error Path (LCOW-only) ──────────────────────── + +func TestStartVM_AddSecurityPolicyFails(t *testing.T) { + ctx := context.Background() + + c, uvm, guest := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + uvm.EXPECT().ID().Return("test-vm-id").AnyTimes() + + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return &fakeListener{}, nil + }) + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeConn{}, nil).AnyTimes() + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + uvm.EXPECT().Wait(gomock.Any()).Return(nil).AnyTimes() + guest.EXPECT().CreateConnection(gomock.Any()).Return(nil) + + // Set confidential config so buildConfidentialOptions returns non-nil. + c.sandboxOptions = &lcow.SandboxOptions{ + ConfidentialConfig: &lcow.ConfidentialConfig{ + SecurityPolicy: "test-policy", + SecurityPolicyEnforcer: "test-enforcer", + UvmReferenceInfoFile: "test-ref", + }, + } + + // Inject parseUVMReferenceInfo to succeed. + origParse := parseUVMReferenceInfo + t.Cleanup(func() { parseUVMReferenceInfo = origParse }) + parseUVMReferenceInfo = func(_ context.Context, _, _ string) (string, error) { + return "encoded-ref-info", nil + } + + // AddSecurityPolicy returns an error. + guest.EXPECT().AddSecurityPolicy(gomock.Any(), gomock.Any()).Return(errors.New("security policy failed")) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when AddSecurityPolicy fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } +} + +// ─── StartVM: buildConfidentialOptions Error Path (LCOW-only) ────────────────── + +func TestStartVM_BuildConfidentialOptionsFails(t *testing.T) { + ctx := context.Background() + + c, uvm, guest := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + uvm.EXPECT().ID().Return("test-vm-id").AnyTimes() + + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return &fakeListener{}, nil + }) + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeConn{}, nil).AnyTimes() + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + uvm.EXPECT().Wait(gomock.Any()).Return(nil).AnyTimes() + guest.EXPECT().CreateConnection(gomock.Any()).Return(nil) + + // Set confidential config so buildConfidentialOptions is called. + c.sandboxOptions = &lcow.SandboxOptions{ + ConfidentialConfig: &lcow.ConfidentialConfig{ + UvmReferenceInfoFile: "test-ref", + }, + } + + // Inject parseUVMReferenceInfo to fail. + origParse := parseUVMReferenceInfo + t.Cleanup(func() { parseUVMReferenceInfo = origParse }) + parseUVMReferenceInfo = func(_ context.Context, _, _ string) (string, error) { + return "", errors.New("parse reference info failed") + } + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when buildConfidentialOptions fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } +} diff --git a/internal/controller/vm/vm_lcow.go b/internal/controller/vm/vm_lcow.go index 40aac30edd..b817f70ef8 100644 --- a/internal/controller/vm/vm_lcow.go +++ b/internal/controller/vm/vm_lcow.go @@ -21,6 +21,13 @@ import ( "golang.org/x/sync/errgroup" ) +var ( + // buildSandboxConfig builds the HCS compute system document for an LCOW sandbox. + buildSandboxConfig = lcow.BuildSandboxConfig + // parseUVMReferenceInfo reads and encodes UVM reference metadata. + parseUVMReferenceInfo = vmutils.ParseUVMReferenceInfo +) + // platformControllers holds platform-specific sub-controllers embedded in [Controller]. // For LCOW, this includes the Plan9 file share controller. type platformControllers struct { @@ -42,7 +49,7 @@ func (c *Controller) SandboxOptions() *lcow.SandboxOptions { // buildHCSConfig builds the HCS document for an LCOW VM by calling lcow.BuildSandboxConfig. // It also stores the sandbox options within the controller. func (c *Controller) buildHCSConfig(ctx context.Context, opts *CreateOptions) (*hcsschema.ComputeSystem, error) { - hcsDocument, sandboxOptions, err := lcow.BuildSandboxConfig(ctx, opts.Owner, opts.BundlePath, opts.ShimOpts, opts.SandboxSpec) + hcsDocument, sandboxOptions, err := buildSandboxConfig(ctx, opts.Owner, opts.BundlePath, opts.ShimOpts, opts.SandboxSpec) if err != nil { return nil, fmt.Errorf("failed to parse sandbox spec: %w", err) } @@ -59,7 +66,7 @@ func (c *Controller) buildConfidentialOptions(ctx context.Context) (*guestresour return nil, nil } - uvmReferenceInfoEncoded, err := vmutils.ParseUVMReferenceInfo( + uvmReferenceInfoEncoded, err := parseUVMReferenceInfo( ctx, vmutils.DefaultLCOWOSBootFilesPath(), c.sandboxOptions.ConfidentialConfig.UvmReferenceInfoFile, diff --git a/internal/controller/vm/vm_test.go b/internal/controller/vm/vm_test.go new file mode 100644 index 0000000000..d64304dcc4 --- /dev/null +++ b/internal/controller/vm/vm_test.go @@ -0,0 +1,1164 @@ +//go:build windows && (lcow || wcow) + +package vm + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Microsoft/go-winio" + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/go-winio/pkg/process" + "github.com/Microsoft/hcsshim/internal/cmd" + "github.com/Microsoft/hcsshim/internal/controller/vm/mocks" + "github.com/Microsoft/hcsshim/internal/gcs" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/shimdiag" + iwin "github.com/Microsoft/hcsshim/internal/windows" + + "go.uber.org/mock/gomock" + "golang.org/x/sys/windows" +) + +// ─── helpers ─────────────────────────────────────────────────────────────────── + +// newControllerWithState returns a Controller wired to gomock uvm/guest mocks +// in the given state. The caller can set expectations on the returned mocks. +func newControllerWithState(t *testing.T, state State) (*Controller, *mocks.MockutilityVM, *mocks.MockguestManager) { + t.Helper() + ctrl := gomock.NewController(t) + uvm := mocks.NewMockutilityVM(ctrl) + guest := mocks.NewMockguestManager(ctrl) + c := &Controller{ + uvm: uvm, + guest: guest, + vmState: state, + logOutputDone: make(chan struct{}), + } + return c, uvm, guest +} + +// stubCaps is a test-only implementation of [gcs.GuestDefinedCapabilities]. +type stubCaps struct{ dumpStacksSupported bool } + +func (s stubCaps) IsSignalProcessSupported() bool { return false } +func (s stubCaps) IsDeleteContainerStateSupported() bool { return false } +func (s stubCaps) IsDumpStacksSupported() bool { return s.dumpStacksSupported } +func (s stubCaps) IsNamespaceAddRequestSupported() bool { return false } + +var _ gcs.GuestDefinedCapabilities = stubCaps{} + +// fakeListener satisfies net.Listener for testing hvsock setup. +type fakeListener struct{ closed atomic.Bool } + +func (f *fakeListener) Accept() (net.Conn, error) { return nil, errors.New("not implemented") } +func (f *fakeListener) Close() error { f.closed.Store(true); return nil } +func (f *fakeListener) Addr() net.Addr { return nil } + +// fakeConn satisfies net.Conn for testing hvsock connections. +type fakeConn struct{} + +func (f *fakeConn) Read([]byte) (int, error) { return 0, io.EOF } +func (f *fakeConn) Write(b []byte) (int, error) { return len(b), nil } +func (f *fakeConn) Close() error { return nil } +func (f *fakeConn) LocalAddr() net.Addr { return nil } +func (f *fakeConn) RemoteAddr() net.Addr { return nil } +func (f *fakeConn) SetDeadline(time.Time) error { return nil } +func (f *fakeConn) SetReadDeadline(time.Time) error { return nil } +func (f *fakeConn) SetWriteDeadline(time.Time) error { return nil } + +// swapListenHVSock replaces the package-level listenHVSock for the duration of t. +func swapListenHVSock(t *testing.T, fn func(*winio.HvsockAddr) (net.Listener, error)) { + t.Helper() + orig := listenHVSock + t.Cleanup(func() { listenHVSock = orig }) + listenHVSock = fn +} + +// swapLookupVMMEM replaces the package-level lookupVMMEM for the duration of t. +func swapLookupVMMEM(t *testing.T, fn func(context.Context, guid.GUID, iwin.API) (windows.Handle, error)) { + t.Helper() + orig := lookupVMMEM + t.Cleanup(func() { lookupVMMEM = orig }) + lookupVMMEM = fn +} + +// swapGetProcessMemoryInfo replaces the package-level getProcessMemoryInfo for the duration of t. +func swapGetProcessMemoryInfo(t *testing.T, fn func(windows.Handle) (*process.ProcessMemoryCountersEx, error)) { + t.Helper() + orig := getProcessMemoryInfo + t.Cleanup(func() { getProcessMemoryInfo = orig }) + getProcessMemoryInfo = fn +} + +// testGUID is a fixed GUID for test assertions. +var testGUID = guid.GUID{Data1: 0xDEADBEEF} + +// ─── 1. State Machine Guards ─────────────────────────────────────────────────── + +func TestStateGuards(t *testing.T) { + ctx := context.Background() + + t.Run("CreateVM/already_created", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateCreated) + err := c.CreateVM(ctx, &CreateOptions{ID: "test-vm"}) + if err == nil { + t.Error("expected error for CreateVM on already-Created controller") + } + }) + + t.Run("StartVM/not_created", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateNotCreated) + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Error("expected error for StartVM on NotCreated controller") + } + }) + + t.Run("StartVM/already_running_idempotent", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateRunning) + err := c.StartVM(ctx, &StartOptions{}) + if err != nil { + t.Errorf("expected nil for StartVM on already-Running controller, got: %v", err) + } + }) + + t.Run("TerminateVM/not_created_idempotent", func(t *testing.T) { + c := New() + err := c.TerminateVM(ctx) + if err != nil { + t.Errorf("expected nil for TerminateVM on NotCreated controller, got: %v", err) + } + }) + + t.Run("TerminateVM/already_terminated_idempotent", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateTerminated) + err := c.TerminateVM(ctx) + if err != nil { + t.Errorf("expected nil for TerminateVM on Terminated controller, got: %v", err) + } + }) + + // Table-driven: methods that require StateRunning. + runningOnlyTests := []struct { + name string + call func(*Controller) error + }{ + { + name: "UpdateCPU", + call: func(c *Controller) error { + return c.UpdateCPU(ctx, &hcsschema.ProcessorLimits{}) + }, + }, + { + name: "UpdateMemory", + call: func(c *Controller) error { + return c.UpdateMemory(ctx, 1024) + }, + }, + { + name: "UpdateCPUGroup", + call: func(c *Controller) error { + return c.UpdateCPUGroup(ctx, "some-group-id") + }, + }, + { + name: "UpdatePolicyFragment", + call: func(c *Controller) error { + return c.UpdatePolicyFragment(ctx, guestresource.SecurityPolicyFragment{}) + }, + }, + { + name: "DumpStacks", + call: func(c *Controller) error { + _, err := c.DumpStacks(ctx) + return err + }, + }, + { + name: "Stats", + call: func(c *Controller) error { + _, err := c.Stats(ctx) + return err + }, + }, + { + name: "ExecIntoHost", + call: func(c *Controller) error { + _, err := c.ExecIntoHost(ctx, &shimdiag.ExecProcessRequest{Args: []string{"echo"}}) + return err + }, + }, + } + for _, tc := range runningOnlyTests { + t.Run(tc.name+"/not_running", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateCreated) + if err := tc.call(c); err == nil { + t.Errorf("expected error for %s on non-Running controller", tc.name) + } + }) + } + + t.Run("Wait/not_created", func(t *testing.T) { + c := New() + err := c.Wait(ctx) + if err == nil { + t.Error("expected error for Wait on NotCreated controller") + } + }) + + t.Run("ExitStatus/not_terminated", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateRunning) + _, err := c.ExitStatus() + if err == nil { + t.Error("expected error for ExitStatus on non-Terminated controller") + } + }) +} + +// ─── 2. TerminateVM Cleanup Chain ────────────────────────────────────────────── + +func TestTerminateVM(t *testing.T) { + ctx := context.Background() + + t.Run("success/running_to_terminated", func(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + + if err := c.TerminateVM(ctx); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c.State() != StateTerminated { + t.Errorf("expected StateTerminated, got %s", c.State()) + } + }) + + t.Run("close_connection_error_logged_not_returned", func(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(errors.New("guest connection error")) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + + if err := c.TerminateVM(ctx); err != nil { + t.Errorf("expected nil (CloseConnection error should be logged, not returned), got: %v", err) + } + if c.State() != StateTerminated { + t.Errorf("expected StateTerminated, got %s", c.State()) + } + }) + + t.Run("close_fails/state_invalid", func(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(errors.New("close failed")) + + err := c.TerminateVM(ctx) + if err == nil { + t.Fatal("expected error when uvm.Close fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid after Close failure, got %s", c.State()) + } + }) + + t.Run("double_terminate/second_is_noop", func(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + + // First terminate succeeds. + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + + if err := c.TerminateVM(ctx); err != nil { + t.Fatalf("first TerminateVM: %v", err) + } + + // Second call — no mock expectations needed; should return nil immediately. + if err := c.TerminateVM(ctx); err != nil { + t.Errorf("second TerminateVM should be nil, got: %v", err) + } + }) + + t.Run("terminate_from_invalid/recovers", func(t *testing.T) { + // Simulate: Close failed → StateInvalid, then TerminateVM again. + c, uvm, guest := newControllerWithState(t, StateRunning) + + // First call: Close fails → Invalid. + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(errors.New("close failed")) + + _ = c.TerminateVM(ctx) + if c.State() != StateInvalid { + t.Fatalf("precondition: expected StateInvalid, got %s", c.State()) + } + + // Second call from Invalid: should attempt cleanup again. + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + + if err := c.TerminateVM(ctx); err != nil { + t.Errorf("TerminateVM from Invalid should succeed, got: %v", err) + } + if c.State() != StateTerminated { + t.Errorf("expected StateTerminated, got %s", c.State()) + } + }) +} + +// ─── 3. StartVM Error Cascade ────────────────────────────────────────────────── + +// setupStartVMEnv injects fakes for listenHVSock and AcceptConnection so that +// the entropy/logging hvsock setup succeeds. Returns the created controller. +func setupStartVMEnv(t *testing.T) (*Controller, *mocks.MockutilityVM, *mocks.MockguestManager) { + t.Helper() + + c, uvm, guest := newControllerWithState(t, StateCreated) + + // listenHVSock always returns a fakeListener. + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return &fakeListener{}, nil + }) + + // AcceptConnection returns a fakeConn (for entropy write and log relay). + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeConn{}, nil).AnyTimes() + + // RuntimeID is called by listenHVSock setup (for building the HvsockAddr). + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + // ID is called by ParseGCSLogrus. + uvm.EXPECT().ID().Return("test-vm-id").AnyTimes() + + return c, uvm, guest +} + +func TestStartVM(t *testing.T) { + ctx := context.Background() + + t.Run("prepare_connection_fails/state_invalid", func(t *testing.T) { + c, uvm, guest := setupStartVMEnv(t) + _ = uvm // AcceptConnection/RuntimeID already set up + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(errors.New("prepare failed")) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when PrepareConnection fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("uvm_start_fails/state_invalid", func(t *testing.T) { + c, uvm, guest := setupStartVMEnv(t) + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(errors.New("start failed")) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when uvm.Start fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("create_connection_fails/state_invalid", func(t *testing.T) { + c, uvm, guest := setupStartVMEnv(t) + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + uvm.EXPECT().Wait(gomock.Any()).Return(nil).AnyTimes() + guest.EXPECT().CreateConnection(gomock.Any()).Return(errors.New("create conn failed")) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when CreateConnection fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("add_security_policy_fails/state_invalid", func(t *testing.T) { + // Covered by TestStartVM_AddSecurityPolicyFails in vm_createvm_lcow_test.go + // (requires LCOW-specific sandbox options that can't be set portably). + t.Skip("covered by LCOW-specific test file") + }) + + t.Run("full_success/state_running", func(t *testing.T) { + // On WCOW, finalizeGCSConnection type-asserts c.guest to + // *guestmanager.Guest (for UpdateHvSocketAddress, which is + // WCOW-only). This panics with a mock. The LCOW path is a no-op. + // WCOW's finalizeGCSConnection is tested at the Guest level. + if !isLCOW() { + t.Skip("WCOW finalizeGCSConnection requires concrete Guest (type assertion)") + } + c, uvm, guest := setupStartVMEnv(t) + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + + // Block the background waitForVMExit goroutine so it doesn't + // race with the state assertion below. Without this, the mock + // Wait returns immediately and the goroutine transitions state + // to Terminated before we can check StateRunning. + waitCh := make(chan struct{}) + t.Cleanup(func() { close(waitCh) }) + uvm.EXPECT().Wait(gomock.Any()).DoAndReturn(func(ctx context.Context) error { + <-waitCh + return nil + }).AnyTimes() + + guest.EXPECT().CreateConnection(gomock.Any()).Return(nil) + + err := c.StartVM(ctx, &StartOptions{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c.State() != StateRunning { + t.Errorf("expected StateRunning, got %s", c.State()) + } + }) + + t.Run("entropy_listener_bind_fails/state_invalid", func(t *testing.T) { + if !isLCOW() { + t.Skip("WCOW setupEntropyListener is a no-op") + } + c, uvm, _ := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + + // First listenHVSock call (entropy) fails. + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return nil, errors.New("entropy bind failed") + }) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when entropy listener bind fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("logging_listener_bind_fails/state_invalid", func(t *testing.T) { + if !isLCOW() { + t.Skip("WCOW setupLoggingListener uses a standalone goroutine, not errgroup") + } + c, uvm, _ := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + // The entropy goroutine is already dispatched on the errgroup + // before the logging listener bind fails; provide AcceptConnection + // for it so it doesn't panic. + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeConn{}, nil).AnyTimes() + + // First call (entropy) succeeds, second call (logging) fails. + callCount := 0 + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + callCount++ + if callCount == 1 { + return &fakeListener{}, nil + } + return nil, errors.New("logging bind failed") + }) + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when logging listener bind fails") + } + // logOutputDone should be closed on logging listener failure. + select { + case <-c.logOutputDone: + default: + t.Error("expected logOutputDone to be closed when logging listener fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("entropy_accept_fails/errgroup_error", func(t *testing.T) { + if !isLCOW() { + t.Skip("WCOW setupEntropyListener is a no-op") + } + c, uvm, guest := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + uvm.EXPECT().ID().Return("test-vm-id").AnyTimes() + + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return &fakeListener{}, nil + }) + // AcceptConnection fails (entropy accept error surfaces via errgroup). + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("entropy accept failed")).AnyTimes() + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + uvm.EXPECT().Wait(gomock.Any()).Return(nil).AnyTimes() + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when entropy AcceptConnection fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) + + t.Run("logging_accept_fails/state_invalid", func(t *testing.T) { + if !isLCOW() { + t.Skip("WCOW setupLoggingListener uses a standalone goroutine, not errgroup") + } + c, uvm, guest := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID).AnyTimes() + uvm.EXPECT().ID().Return("test-vm-id").AnyTimes() + + swapListenHVSock(t, func(_ *winio.HvsockAddr) (net.Listener, error) { + return &fakeListener{}, nil + }) + // First AcceptConnection (entropy) succeeds, second (logging) fails. + acceptCount := 0 + uvm.EXPECT().AcceptConnection(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ net.Listener, _ bool) (net.Conn, error) { + acceptCount++ + if acceptCount == 1 { + return &fakeConn{}, nil + } + return nil, errors.New("logging accept failed") + }).AnyTimes() + + guest.EXPECT().PrepareConnection(gomock.Any()).Return(nil) + uvm.EXPECT().Start(gomock.Any()).Return(nil) + uvm.EXPECT().Wait(gomock.Any()).Return(nil).AnyTimes() + + err := c.StartVM(ctx, &StartOptions{}) + if err == nil { + t.Fatal("expected error when logging AcceptConnection fails") + } + if c.State() != StateInvalid { + t.Errorf("expected StateInvalid, got %s", c.State()) + } + }) +} + +// ─── 4. waitForVMExit Race ───────────────────────────────────────────────────── + +func TestWaitForVMExit(t *testing.T) { + ctx := context.Background() + + t.Run("natural_exit/sets_terminated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + + waitDone := make(chan struct{}) + uvm.EXPECT().Wait(gomock.Any()).DoAndReturn(func(context.Context) error { + <-waitDone + return nil + }) + + go c.waitForVMExit(ctx) + + // Unblock the Wait. + close(waitDone) + + // Give the goroutine a moment to acquire the lock and set state. + time.Sleep(50 * time.Millisecond) + + if c.State() != StateTerminated { + t.Errorf("expected StateTerminated after natural exit, got %s", c.State()) + } + }) + + t.Run("concurrent_terminate/waitForVMExit_noops", func(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + + waitCh := make(chan struct{}) + uvm.EXPECT().Wait(gomock.Any()).DoAndReturn(func(context.Context) error { + <-waitCh + return nil + }).AnyTimes() + + go c.waitForVMExit(ctx) + + // TerminateVM runs first and sets Terminated. + uvm.EXPECT().Terminate(gomock.Any()).Return(nil) + guest.EXPECT().CloseConnection().Return(nil) + uvm.EXPECT().Close(gomock.Any()).Return(nil) + + if err := c.TerminateVM(ctx); err != nil { + t.Fatalf("TerminateVM: %v", err) + } + + // Now unblock waitForVMExit — it should see Terminated and no-op. + close(waitCh) + time.Sleep(50 * time.Millisecond) + + if c.State() != StateTerminated { + t.Errorf("expected StateTerminated, got %s", c.State()) + } + }) +} + +// ─── 5. ExecIntoHost ─────────────────────────────────────────────────────────── + +func TestExecIntoHost(t *testing.T) { + ctx := context.Background() + + t.Run("terminal_with_stderr/returns_error", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateRunning) + _, err := c.ExecIntoHost(ctx, &shimdiag.ExecProcessRequest{ + Terminal: true, + Stderr: "/some/path", + }) + if err == nil { + t.Error("expected error when terminal=true and stderr is set") + } + }) + + t.Run("success/returns_exit_code", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).Return(42, nil) + + exitCode, err := c.ExecIntoHost(ctx, &shimdiag.ExecProcessRequest{ + Args: []string{"echo", "hello"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exitCode != 42 { + t.Errorf("expected exit code 42, got %d", exitCode) + } + }) + + t.Run("guest_error/propagated", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).Return(-1, errors.New("exec failed")) + + _, err := c.ExecIntoHost(ctx, &shimdiag.ExecProcessRequest{ + Args: []string{"fail"}, + }) + if err == nil { + t.Error("expected error from guest.ExecIntoUVM to be propagated") + } + }) + + t.Run("active_exec_count_tracking", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + + // Block until we verify the count. + execStarted := make(chan struct{}) + execContinue := make(chan struct{}) + guest.EXPECT().ExecIntoUVM(gomock.Any(), gomock.Any()).DoAndReturn( + func(context.Context, *cmd.CmdProcessRequest) (int, error) { + close(execStarted) + <-execContinue + return 0, nil + }, + ) + + go func() { + _, _ = c.ExecIntoHost(ctx, &shimdiag.ExecProcessRequest{Args: []string{"sleep"}}) + }() + + <-execStarted + if count := c.activeExecCount.Load(); count != 1 { + t.Errorf("expected activeExecCount=1, got %d", count) + } + + close(execContinue) + time.Sleep(50 * time.Millisecond) + + if count := c.activeExecCount.Load(); count != 0 { + t.Errorf("expected activeExecCount=0 after exec, got %d", count) + } + }) +} + +// ─── 6. DumpStacks ───────────────────────────────────────────────────────────── + +func TestDumpStacks(t *testing.T) { + ctx := context.Background() + + t.Run("supported/calls_guest", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().Capabilities().Return(stubCaps{dumpStacksSupported: true}) + guest.EXPECT().DumpStacks(gomock.Any()).Return("goroutine 1 [running]:\n...", nil) + + result, err := c.DumpStacks(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Error("expected non-empty stack dump") + } + }) + + t.Run("unsupported/returns_empty", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().Capabilities().Return(stubCaps{dumpStacksSupported: false}) + // DumpStacks should NOT be called. + + result, err := c.DumpStacks(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "" { + t.Errorf("expected empty result when unsupported, got: %q", result) + } + }) + + t.Run("guest_error/propagated", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().Capabilities().Return(stubCaps{dumpStacksSupported: true}) + guest.EXPECT().DumpStacks(gomock.Any()).Return("", errors.New("dump failed")) + + _, err := c.DumpStacks(ctx) + if err == nil { + t.Error("expected error from guest.DumpStacks to be propagated") + } + }) +} + +// ─── 7. Wait ─────────────────────────────────────────────────────────────────── + +func TestWait(t *testing.T) { + t.Run("not_created/returns_error", func(t *testing.T) { + c := New() + err := c.Wait(context.Background()) + if err == nil { + t.Error("expected error for Wait on NotCreated controller") + } + }) + + t.Run("vm_exits_normally", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + close(c.logOutputDone) // simulate log processing already done + + uvm.EXPECT().Wait(gomock.Any()).Return(nil) + + err := c.Wait(context.Background()) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("context_cancelled/joined_error", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + // logOutputDone never closes → ctx cancellation should produce an error. + + uvm.EXPECT().Wait(gomock.Any()).Return(nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + err := c.Wait(ctx) + if err == nil { + t.Error("expected error when context is cancelled and logOutputDone is not closed") + } + }) + + t.Run("uvm_wait_error/propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + close(c.logOutputDone) + + uvm.EXPECT().Wait(gomock.Any()).Return(errors.New("vm crashed")) + + err := c.Wait(context.Background()) + if err == nil { + t.Error("expected error when uvm.Wait fails") + } + }) + + t.Run("terminated_vm/immediate_return", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateTerminated) + close(c.logOutputDone) + + uvm.EXPECT().Wait(gomock.Any()).Return(nil) + + err := c.Wait(context.Background()) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) +} + +// ─── 8. Stats ────────────────────────────────────────────────────────────────── + +func TestStats(t *testing.T) { + ctx := context.Background() + + t.Run("not_running/returns_error", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateCreated) + _, err := c.Stats(ctx) + if err == nil { + t.Error("expected error for Stats on non-Running controller") + } + }) + + t.Run("lookupVMMEM_fails/error_propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().RuntimeID().Return(testGUID) + + swapLookupVMMEM(t, func(_ context.Context, _ guid.GUID, _ iwin.API) (windows.Handle, error) { + return 0, errors.New("vmmem not found") + }) + + _, err := c.Stats(ctx) + if err == nil { + t.Error("expected error when lookupVMMEM fails") + } + }) + + t.Run("propertiesV2_fails/error_propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().RuntimeID().Return(testGUID) + uvm.EXPECT().PropertiesV2(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("properties failed")) + + swapLookupVMMEM(t, func(_ context.Context, _ guid.GUID, _ iwin.API) (windows.Handle, error) { + return windows.Handle(0x1234), nil + }) + + _, err := c.Stats(ctx) + if err == nil { + t.Error("expected error when PropertiesV2 fails") + } + }) + + t.Run("va_backed/working_set_from_memcounters", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + c.isPhysicallyBacked = false + + uvm.EXPECT().RuntimeID().Return(testGUID) + uvm.EXPECT().PropertiesV2(gomock.Any(), gomock.Any(), gomock.Any()).Return(&hcsschema.Properties{ + Statistics: &hcsschema.Statistics{ + Processor: &hcsschema.ProcessorStats{ + TotalRuntime100ns: 5000, + }, + }, + Memory: &hcsschema.MemoryInformationForVm{ + VirtualMachineMemory: &hcsschema.VmMemory{ + AssignedMemory: 1024, + }, + }, + }, nil) + + const fakeWSS uint = 8192 + swapLookupVMMEM(t, func(_ context.Context, _ guid.GUID, _ iwin.API) (windows.Handle, error) { + return windows.Handle(0x1234), nil + }) + swapGetProcessMemoryInfo(t, func(_ windows.Handle) (*process.ProcessMemoryCountersEx, error) { + return &process.ProcessMemoryCountersEx{WorkingSetSize: fakeWSS}, nil + }) + + s, err := c.Stats(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s.Memory == nil { + t.Fatal("expected non-nil memory stats") + } + if s.Memory.WorkingSetBytes != uint64(fakeWSS) { + t.Errorf("expected WorkingSetBytes=%d (from memCounters), got %d", fakeWSS, s.Memory.WorkingSetBytes) + } + if s.Processor == nil || s.Processor.TotalRuntimeNS != 500000 { + t.Errorf("expected TotalRuntimeNS=500000, got %d", s.Processor.TotalRuntimeNS) + } + }) + + t.Run("physically_backed/working_set_from_assigned_memory", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + c.isPhysicallyBacked = true + + uvm.EXPECT().RuntimeID().Return(testGUID) + uvm.EXPECT().PropertiesV2(gomock.Any(), gomock.Any(), gomock.Any()).Return(&hcsschema.Properties{ + Statistics: &hcsschema.Statistics{ + Processor: &hcsschema.ProcessorStats{ + TotalRuntime100ns: 100, + }, + }, + Memory: &hcsschema.MemoryInformationForVm{ + VirtualMachineMemory: &hcsschema.VmMemory{ + AssignedMemory: 256, + }, + }, + }, nil) + + swapLookupVMMEM(t, func(_ context.Context, _ guid.GUID, _ iwin.API) (windows.Handle, error) { + return windows.Handle(0x5678), nil + }) + // getProcessMemoryInfo should NOT be called for physically-backed VMs. + // If it is, the test will fail because we haven't set an expectation. + + s, err := c.Stats(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s.Memory == nil { + t.Fatal("expected non-nil memory stats") + } + // AssignedMemory * 4096 + expected := uint64(256 * 4096) + if s.Memory.WorkingSetBytes != expected { + t.Errorf("expected WorkingSetBytes=%d (AssignedMemory*4096), got %d", expected, s.Memory.WorkingSetBytes) + } + }) + + t.Run("get_process_memory_info_fails/error_propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + c.isPhysicallyBacked = false + + uvm.EXPECT().RuntimeID().Return(testGUID) + uvm.EXPECT().PropertiesV2(gomock.Any(), gomock.Any(), gomock.Any()).Return(&hcsschema.Properties{ + Statistics: &hcsschema.Statistics{ + Processor: &hcsschema.ProcessorStats{}, + }, + Memory: &hcsschema.MemoryInformationForVm{ + VirtualMachineMemory: &hcsschema.VmMemory{}, + }, + }, nil) + + swapLookupVMMEM(t, func(_ context.Context, _ guid.GUID, _ iwin.API) (windows.Handle, error) { + return windows.Handle(0x1234), nil + }) + swapGetProcessMemoryInfo(t, func(_ windows.Handle) (*process.ProcessMemoryCountersEx, error) { + return nil, errors.New("memory info failed") + }) + + _, err := c.Stats(ctx) + if err == nil { + t.Error("expected error when getProcessMemoryInfo fails") + } + }) +} + +// ─── 9. Accessor Methods ────────────────────────────────────────────────────── + +func TestRuntimeID(t *testing.T) { + t.Run("created/returns_guid_string", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateCreated) + uvm.EXPECT().RuntimeID().Return(testGUID) + + rid := c.RuntimeID() + if rid == "" { + t.Error("expected non-empty RuntimeID for Created controller") + } + }) + + t.Run("not_created/returns_empty", func(t *testing.T) { + c := New() + if rid := c.RuntimeID(); rid != "" { + t.Errorf("expected empty RuntimeID for NotCreated controller, got %q", rid) + } + }) +} + +func TestStartTime(t *testing.T) { + t.Run("running/returns_start_time", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + now := time.Now() + uvm.EXPECT().StartedTime().Return(now) + + st := c.StartTime() + if st != now { + t.Errorf("expected start time %v, got %v", now, st) + } + }) + + t.Run("not_created/returns_zero", func(t *testing.T) { + c := New() + st := c.StartTime() + if !st.IsZero() { + t.Errorf("expected zero start time for NotCreated, got %v", st) + } + }) +} + +func TestExitStatus(t *testing.T) { + t.Run("terminated/returns_status", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateTerminated) + now := time.Now() + exitErr := errors.New("vm crashed") + uvm.EXPECT().StoppedTime().Return(now) + uvm.EXPECT().ExitError().Return(exitErr) + + es, err := c.ExitStatus() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if es.StoppedTime != now { + t.Errorf("expected StoppedTime %v, got %v", now, es.StoppedTime) + } + if !errors.Is(es.Err, exitErr) { + t.Errorf("expected ExitError %v, got %v", exitErr, es.Err) + } + }) + + t.Run("running/returns_error", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateRunning) + _, err := c.ExitStatus() + if err == nil { + t.Error("expected error for ExitStatus on Running controller") + } + }) +} + +// ─── 10. UpdateCPUGroup ──────────────────────────────────────────────────────── + +func TestUpdateCPUGroup(t *testing.T) { + ctx := context.Background() + + t.Run("empty_id/returns_error", func(t *testing.T) { + c, _, _ := newControllerWithState(t, StateRunning) + err := c.UpdateCPUGroup(ctx, "") + if err == nil { + t.Error("expected error when cpuGroupID is empty") + } + }) + + t.Run("success", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().SetCPUGroup(gomock.Any(), gomock.Any()).Return(nil) + + err := c.UpdateCPUGroup(ctx, "group-123") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("set_cpu_group_fails/error_propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().SetCPUGroup(gomock.Any(), gomock.Any()).Return(errors.New("set group failed")) + + err := c.UpdateCPUGroup(ctx, "group-123") + if err == nil { + t.Error("expected error when SetCPUGroup fails") + } + }) +} + +// ─── 12. UpdateCPU ───────────────────────────────────────────────────────────── + +func TestUpdateCPU(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().UpdateCPULimits(gomock.Any(), gomock.Any()).Return(nil) + + err := c.UpdateCPU(ctx, &hcsschema.ProcessorLimits{}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("uvm_error/propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().UpdateCPULimits(gomock.Any(), gomock.Any()).Return(errors.New("cpu limits failed")) + + err := c.UpdateCPU(ctx, &hcsschema.ProcessorLimits{}) + if err == nil { + t.Error("expected error when UpdateCPULimits fails") + } + }) +} + +// ─── 13. UpdateMemory ────────────────────────────────────────────────────────── + +func TestUpdateMemory(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().UpdateMemory(gomock.Any(), gomock.Any()).Return(nil) + + err := c.UpdateMemory(ctx, 1024) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("uvm_error/propagated", func(t *testing.T) { + c, uvm, _ := newControllerWithState(t, StateRunning) + uvm.EXPECT().UpdateMemory(gomock.Any(), gomock.Any()).Return(errors.New("memory update failed")) + + err := c.UpdateMemory(ctx, 1024) + if err == nil { + t.Error("expected error when UpdateMemory fails") + } + }) +} + +// ─── 14. UpdatePolicyFragment ────────────────────────────────────────────────── + +func TestUpdatePolicyFragment(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().InjectPolicyFragment(gomock.Any(), gomock.Any()).Return(nil) + + err := c.UpdatePolicyFragment(ctx, guestresource.SecurityPolicyFragment{}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("guest_error/propagated", func(t *testing.T) { + c, _, guest := newControllerWithState(t, StateRunning) + guest.EXPECT().InjectPolicyFragment(gomock.Any(), gomock.Any()).Return(errors.New("inject failed")) + + err := c.UpdatePolicyFragment(ctx, guestresource.SecurityPolicyFragment{}) + if err == nil { + t.Error("expected error when InjectPolicyFragment fails") + } + }) +} + +// ─── 15. Concurrent access sanity ────────────────────────────────────────────── + +func TestConcurrentStateAccess(t *testing.T) { + c, uvm, guest := newControllerWithState(t, StateRunning) + ctx := context.Background() + + // Allow TerminateVM to be called from one goroutine. + uvm.EXPECT().Terminate(gomock.Any()).Return(nil).AnyTimes() + guest.EXPECT().CloseConnection().Return(nil).AnyTimes() + uvm.EXPECT().Close(gomock.Any()).Return(nil).AnyTimes() + + var wg sync.WaitGroup + // Read state concurrently while another goroutine terminates. + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = c.State() + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + _ = c.TerminateVM(ctx) + }() + + wg.Wait() + // No race detector panic = pass. +} diff --git a/internal/controller/vm/vm_wcow.go b/internal/controller/vm/vm_wcow.go index 42303b82a0..de6d7f5b7a 100644 --- a/internal/controller/vm/vm_wcow.go +++ b/internal/controller/vm/vm_wcow.go @@ -53,7 +53,7 @@ func (c *Controller) setupEntropyListener(_ context.Context, _ *errgroup.Group) // to prevent resource exhaustion, but will accept new connections if the current one is closed. // This supports scenarios where the logging service inside the VM needs to restart. func (c *Controller) setupLoggingListener(ctx context.Context, _ *errgroup.Group) error { - baseListener, err := winio.ListenHvsock(&winio.HvsockAddr{ + baseListener, err := listenHVSock(&winio.HvsockAddr{ VMID: c.uvm.RuntimeID(), ServiceID: prot.WindowsLoggingHvsockServiceID, })