diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 43669a6f26..0cd6448cae 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -116,6 +116,44 @@ type Host struct { // hostMounts keeps the state of currently mounted devices and file systems, // which is used for GCS hardening. hostMounts *hostMounts + // uvmError contains a permanent flag to indicate that further mounts, + // unmounts and container creation / deletion should not be allowed. This + // is set when, because of a failure during an unmount operation, we end up + // in a state where the policy enforcer's state is out of sync with what we + // have actually done, but we cannot safely revert its state. + // + // Only consulted in confidential mode (see Host.checkState). + uvmError uvmConsistencyError +} + +type uvmConsistencyError struct { + mu sync.Mutex + // The error describing why the UVM has entered an inconsistent state. If + // this is nil, there is no error and Check() returns nil. + cause error +} + +// Mark that the UVM has entered an inconsistent state, and store the cause if +// it is not already set. +func (u *uvmConsistencyError) Set(cause error) { + u.mu.Lock() + defer u.mu.Unlock() + if u.cause == nil { + u.cause = cause + } +} + +// Check returns a non-nil error if the UVM has been marked inconsistent. +func (u *uvmConsistencyError) Check() error { + u.mu.Lock() + defer u.mu.Unlock() + if u.cause == nil { + return nil + } + return fmt.Errorf( + "mount, unmount, container creation and deletion have been disabled in this UVM due to a previous error: %w", + u.cause, + ) } func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Host { @@ -135,6 +173,7 @@ func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer s devNullTransport: &transport.DevNullTransport{}, hostMounts: newHostMounts(), securityOptions: securityPolicyOptions, + uvmError: uvmConsistencyError{}, } } @@ -386,7 +425,43 @@ func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHost return nil } +// checkState returns an error if the UVM has entered an inconsistent state from +// which it cannot safely recover. Only enforced for confidential containers. +func (h *Host) checkState() error { + if h.HasSecurityPolicy() { + return h.uvmError.Check() + } + return nil +} + +// setUVMInconsistent records that the UVM has entered an inconsistent state and +// logs the cause. The flag is only consulted in confidential mode (see +// checkState), so this is a no-op for non-confidential hosts. +func (h *Host) setUVMInconsistent(cause error) { + if !h.HasSecurityPolicy() { + return + } + h.uvmError.Set(cause) + log.G(context.Background()).WithFields(logrus.Fields{ + "cause": cause, + }).Error("Host marked inconsistent. All further mounts/unmounts, container creation and deletion will fail.") +} + +func checkExists(path string) (bool, error) { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, errors.Wrapf(err, "failed to determine if path '%s' exists", path) + } + return true, nil +} + func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { + if err = h.checkState(); err != nil { + return nil, err + } + criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] // Check for virtual pod annotation @@ -738,6 +813,10 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * case guestresource.ResourceTypeSCSIDevice: return modifySCSIDevice(ctx, req.RequestType, req.Settings.(*guestresource.SCSIDevice)) case guestresource.ResourceTypeMappedVirtualDisk: + if err := h.checkState(); err != nil { + return err + } + mvd := req.Settings.(*guestresource.LCOWMappedVirtualDisk) // find the actual controller number on the bus and update the incoming request. var cNum uint8 @@ -775,18 +854,30 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * }() } } - return modifyMappedVirtualDisk(ctx, req.RequestType, mvd, h.securityOptions.PolicyEnforcer) + return h.modifyMappedVirtualDisk(ctx, req.RequestType, mvd) case guestresource.ResourceTypeMappedDirectory: - return modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory), h.securityOptions.PolicyEnforcer) + if err := h.checkState(); err != nil { + return err + } + + return h.modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory)) case guestresource.ResourceTypeVPMemDevice: - return modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice), h.securityOptions.PolicyEnforcer) + if err := h.checkState(); err != nil { + return err + } + + return h.modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice)) case guestresource.ResourceTypeCombinedLayers: + if err := h.checkState(); err != nil { + return err + } + cl := req.Settings.(*guestresource.LCOWCombinedLayers) // when cl.ScratchPath == "", we mount overlay as read-only, in which case // we don't really care about scratch encryption, since the host already // knows about the layers and the overlayfs. encryptedScratch := cl.ScratchPath != "" && h.hostMounts.IsEncrypted(cl.ScratchPath) - return modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch, h.securityOptions.PolicyEnforcer) + return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch) case guestresource.ResourceTypeNetwork: return modifyNetwork(ctx, req.RequestType, req.Settings.(*guestresource.LCOWNetworkAdapter)) case guestresource.ResourceTypeVPCIDevice: @@ -1174,19 +1265,19 @@ func modifySCSIDevice( } } -func modifyMappedVirtualDisk( +func (h *Host) modifyMappedVirtualDisk( ctx context.Context, rt guestrequest.RequestType, mvd *guestresource.LCOWMappedVirtualDisk, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer if mvd.ReadOnly { // The only time the policy is empty, and we want it to be empty // is when no policy is provided, and we default to open door // policy. In any other case, e.g. explicit open door or any // other rego policy we would like to mount layers with verity. - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) if err != nil { return err @@ -1200,101 +1291,154 @@ func modifyMappedVirtualDisk( } } } - switch rt { - case guestrequest.RequestTypeAdd: - mountCtx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - if mvd.MountPath != "" { - if mvd.ReadOnly { - var deviceHash string - if verityInfo != nil { - deviceHash = verityInfo.RootDigest - } - err = securityPolicy.EnforceDeviceMountPolicy(ctx, mvd.MountPath, deviceHash) - if err != nil { - return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + + // For confidential containers, we revert the policy metadata via the + // transaction rollback mechanism on both mount and unmount errors, but if + // we've actually called Unmount and it fails we permanently block further + // device operations by marking the UVM state as inconsistent. + return securityPolicy.WithMetadataRollback(func() error { + switch rt { + case guestrequest.RequestTypeAdd: + mountCtx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + if mvd.MountPath != "" { + if mvd.ReadOnly { + var deviceHash string + if verityInfo != nil { + deviceHash = verityInfo.RootDigest + } + err = securityPolicy.EnforceDeviceMountPolicy(ctx, mvd.MountPath, deviceHash) + if err != nil { + return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + } + } else { + err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) + if err != nil { + return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + } } - } else { - err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) - if err != nil { - return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + config := &scsi.Config{ + Encrypted: mvd.Encrypted, + VerityInfo: verityInfo, + EnsureFilesystem: mvd.EnsureFilesystem, + Filesystem: mvd.Filesystem, + BlockDev: mvd.BlockDev, } + // Since we're rolling back the policy metadata on failure, we + // need to ensure that we have reverted all the side effects + // from this failed mount attempt, otherwise the Rego metadata + // is technically still inconsistent with reality. Mount cleans + // up the created directory and dm devices on failure, so we're + // good. + return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, + mvd.ReadOnly, mvd.Options, config) } - config := &scsi.Config{ - Encrypted: mvd.Encrypted, - VerityInfo: verityInfo, - EnsureFilesystem: mvd.EnsureFilesystem, - Filesystem: mvd.Filesystem, - BlockDev: mvd.BlockDev, - } - return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, - mvd.ReadOnly, mvd.Options, config) - } - return nil - case guestrequest.RequestTypeRemove: - if mvd.MountPath != "" { - if mvd.ReadOnly { - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { - return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + return nil + case guestrequest.RequestTypeRemove: + if mvd.MountPath != "" { + if mvd.ReadOnly { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + } + } else { + if err = securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + } } - } else { - if err := securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { - return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + // Check that the directory actually exists first, and if it + // does not then we just refuse to do anything, without closing + // the dm device or marking the UVM inconsistent. Policy + // metadata is still reverted to reflect the fact that we have + // not done anything. + // + // Note: we should not do this check before calling the policy + // enforcer (which we have done above), otherwise we will + // inadvertently allow the host to find out whether an arbitrary + // path (which may point to sensitive data within a container + // rootfs) exists or not + if h.HasSecurityPolicy() { + exists, err := checkExists(mvd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting scsi device at %s failed: directory does not exist", mvd.MountPath) + } + } + config := &scsi.Config{ + Encrypted: mvd.Encrypted, + VerityInfo: verityInfo, + EnsureFilesystem: mvd.EnsureFilesystem, + Filesystem: mvd.Filesystem, + BlockDev: mvd.BlockDev, + } + err = scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, config) + if err != nil { + h.setUVMInconsistent( + fmt.Errorf("unmounting scsi device at %s failed: %w", mvd.MountPath, err), + ) + return err } } - config := &scsi.Config{ - Encrypted: mvd.Encrypted, - VerityInfo: verityInfo, - EnsureFilesystem: mvd.EnsureFilesystem, - Filesystem: mvd.Filesystem, - BlockDev: mvd.BlockDev, - } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, - mvd.MountPath, config); err != nil { - return err - } + return nil + default: + return newInvalidRequestTypeError(rt) } - return nil - default: - return newInvalidRequestTypeError(rt) - } + }) } -func modifyMappedDirectory( +func (h *Host) modifyMappedDirectory( ctx context.Context, vsock transport.Transport, rt guestrequest.RequestType, md *guestresource.LCOWMappedDirectory, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { - switch rt { - case guestrequest.RequestTypeAdd: - err = securityPolicy.EnforcePlan9MountPolicy(ctx, md.MountPath) - if err != nil { - return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath) - } + securityPolicy := h.securityOptions.PolicyEnforcer + // For confidential containers, we revert the policy metadata via the + // transaction rollback mechanism on both mount and unmount errors, but if + // we've actually called Unmount and it fails we permanently block further + // device operations. + return securityPolicy.WithMetadataRollback(func() error { + switch rt { + case guestrequest.RequestTypeAdd: + err = securityPolicy.EnforcePlan9MountPolicy(ctx, md.MountPath) + if err != nil { + return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath) + } - return plan9.Mount(ctx, vsock, md.MountPath, md.ShareName, uint32(md.Port), md.ReadOnly) - case guestrequest.RequestTypeRemove: - err = securityPolicy.EnforcePlan9UnmountPolicy(ctx, md.MountPath) - if err != nil { - return errors.Wrapf(err, "unmounting plan9 device at %s denied by policy", md.MountPath) - } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, plan9.Mount here must clean up + // everything if it fails, which it does do. + return plan9.Mount(ctx, vsock, md.MountPath, md.ShareName, uint32(md.Port), md.ReadOnly) + case guestrequest.RequestTypeRemove: + err = securityPolicy.EnforcePlan9UnmountPolicy(ctx, md.MountPath) + if err != nil { + return errors.Wrapf(err, "unmounting plan9 device at %s denied by policy", md.MountPath) + } - return storage.UnmountPath(ctx, md.MountPath, true) - default: - return newInvalidRequestTypeError(rt) - } + // Note: storage.UnmountPath is nop if path does not exist. + err = storage.UnmountPath(ctx, md.MountPath, true) + if err != nil { + h.setUVMInconsistent( + fmt.Errorf("unmounting plan9 device at %s failed: %w", md.MountPath, err), + ) + return err + } + return nil + default: + return newInvalidRequestTypeError(rt) + } + }) } -func modifyMappedVPMemDevice(ctx context.Context, +func (h *Host) modifyMappedVPMemDevice(ctx context.Context, rt guestrequest.RequestType, vpd *guestresource.LCOWMappedVPMemDevice, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer var deviceHash string - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { if vpd.MappingInfo != nil { return fmt.Errorf("multi mapping is not supported with verity") } @@ -1304,23 +1448,56 @@ func modifyMappedVPMemDevice(ctx context.Context, } deviceHash = verityInfo.RootDigest } - switch rt { - case guestrequest.RequestTypeAdd: - err = securityPolicy.EnforceDeviceMountPolicy(ctx, vpd.MountPath, deviceHash) - if err != nil { - return errors.Wrapf(err, "mounting pmem device %d onto %s denied by policy", vpd.DeviceNumber, vpd.MountPath) - } - return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) - case guestrequest.RequestTypeRemove: - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { - return errors.Wrapf(err, "unmounting pmem device from %s denied by policy", vpd.MountPath) - } + // For confidential containers, we revert the policy metadata via the + // transaction rollback mechanism on both mount and unmount errors, but if + // we've actually called Unmount and it fails we permanently block further + // device operations. + return securityPolicy.WithMetadataRollback(func() error { + switch rt { + case guestrequest.RequestTypeAdd: + err = securityPolicy.EnforceDeviceMountPolicy(ctx, vpd.MountPath, deviceHash) + if err != nil { + return errors.Wrapf(err, "mounting pmem device %d onto %s denied by policy", vpd.DeviceNumber, vpd.MountPath) + } - return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) - default: - return newInvalidRequestTypeError(rt) - } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, pmem.Mount here must clean up + // everything if it fails, which it does do. + return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + case guestrequest.RequestTypeRemove: + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { + return errors.Wrapf(err, "unmounting pmem device from %s denied by policy", vpd.MountPath) + } + + // Check that the directory actually exists first, and if it does not + // then we just refuse to do anything, without closing the dm-linear or + // dm-verity device or marking the UVM inconsistent. + // + // Similar to the reasoning in modifyMappedVirtualDisk, we should not do + // this check before calling the policy enforcer. + if h.HasSecurityPolicy() { + exists, err := checkExists(vpd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting pmem device at %s failed: directory does not exist", vpd.MountPath) + } + } + + err = pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + if err != nil { + h.setUVMInconsistent( + fmt.Errorf("unmounting pmem device at %s failed: %w", vpd.MountPath, err), + ) + return err + } + return nil + default: + return newInvalidRequestTypeError(rt) + } + }) } func modifyMappedVPCIDevice(ctx context.Context, rt guestrequest.RequestType, vpciDev *guestresource.LCOWMappedVPCIDevice) error { @@ -1332,85 +1509,102 @@ func modifyMappedVPCIDevice(ctx context.Context, rt guestrequest.RequestType, vp } } -func modifyCombinedLayers( +func (h *Host) modifyCombinedLayers( ctx context.Context, rt guestrequest.RequestType, cl *guestresource.LCOWCombinedLayers, scratchEncrypted bool, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { - isConfidential := len(securityPolicy.EncodedSecurityPolicy()) > 0 + securityPolicy := h.securityOptions.PolicyEnforcer containerID := cl.ContainerID - switch rt { - case guestrequest.RequestTypeAdd: - if isConfidential { - if err := checkValidContainerID(containerID, "container"); err != nil { - return err - } + // For confidential containers, we revert the policy metadata via the + // transaction rollback mechanism on both mount and unmount errors, but if + // we've actually called Unmount and it fails we permanently block further + // device operations. + return securityPolicy.WithMetadataRollback(func() error { + switch rt { + case guestrequest.RequestTypeAdd: + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } - // We check this regardless of what the policy says, as long as we're in - // confidential mode. This matches with checkContainerSettings called for - // container creation request. - expectedContainerRootfs := path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) - if cl.ContainerRootPath != expectedContainerRootfs { - return fmt.Errorf("combined layers target %q does not match expected path %q", - cl.ContainerRootPath, expectedContainerRootfs) - } + // We check this regardless of what the policy says, as long as we're in + // confidential mode. This matches with checkContainerSettings called for + // container creation request. + expectedContainerRootfs := path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) + if cl.ContainerRootPath != expectedContainerRootfs { + return fmt.Errorf("combined layers target %q does not match expected path %q", + cl.ContainerRootPath, expectedContainerRootfs) + } - if cl.ScratchPath != "" { - // At this point, we do not know what the sandbox ID would be yet, so we - // have to allow anything reasonable. - scratchDirRegexStr := fmt.Sprintf( - "^%s/%s/%s/%s$", - guestpath.LCOWRootPrefixInUVM, - validContainerIDRegexRaw, - guestpath.ScratchDir, - containerID, - ) - scratchDirRegex := regexp.MustCompile(scratchDirRegexStr) - if !scratchDirRegex.MatchString(cl.ScratchPath) { - return fmt.Errorf("scratch path %q must match regex %q", - cl.ScratchPath, scratchDirRegexStr) + if cl.ScratchPath != "" { + // At this point, we do not know what the sandbox ID would be yet, so we + // have to allow anything reasonable. + scratchDirRegexStr := fmt.Sprintf( + "^%s/%s/%s/%s$", + guestpath.LCOWRootPrefixInUVM, + validContainerIDRegexRaw, + guestpath.ScratchDir, + containerID, + ) + scratchDirRegex := regexp.MustCompile(scratchDirRegexStr) + if !scratchDirRegex.MatchString(cl.ScratchPath) { + return fmt.Errorf("scratch path %q must match regex %q", + cl.ScratchPath, scratchDirRegexStr) + } } } - } - layerPaths := make([]string, len(cl.Layers)) - for i, layer := range cl.Layers { - layerPaths[i] = layer.Path - } + layerPaths := make([]string, len(cl.Layers)) + for i, layer := range cl.Layers { + layerPaths[i] = layer.Path + } - var upperdirPath string - var workdirPath string - readonly := false - if cl.ScratchPath == "" { - // The user did not pass a scratch path. Mount overlay as readonly. - readonly = true - } else { - upperdirPath = filepath.Join(cl.ScratchPath, "upper") - workdirPath = filepath.Join(cl.ScratchPath, "work") + var upperdirPath string + var workdirPath string + readonly := false + if cl.ScratchPath == "" { + // The user did not pass a scratch path. Mount overlay as readonly. + readonly = true + } else { + upperdirPath = filepath.Join(cl.ScratchPath, "upper") + workdirPath = filepath.Join(cl.ScratchPath, "work") - if err := securityPolicy.EnforceScratchMountPolicy(ctx, cl.ScratchPath, scratchEncrypted); err != nil { - return fmt.Errorf("scratch mounting denied by policy: %w", err) + if err := securityPolicy.EnforceScratchMountPolicy(ctx, cl.ScratchPath, scratchEncrypted); err != nil { + return fmt.Errorf("scratch mounting denied by policy: %w", err) + } } - } - if err := securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { - return fmt.Errorf("overlay creation denied by policy: %w", err) - } + if err := securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { + return fmt.Errorf("overlay creation denied by policy: %w", err) + } - return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) - case guestrequest.RequestTypeRemove: - // cl.ContainerID is not set on remove requests, but rego checks that we can - // only umount previously mounted targets anyway - if err := securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { - return errors.Wrap(err, "overlay removal denied by policy") - } + // Correctness for policy transaction rollback: + // MountLayer does two things - mkdir, then mount. On mount failure, the + // target directory is cleaned up. Therefore we're clean in terms of + // side effects. + return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) + case guestrequest.RequestTypeRemove: + // cl.ContainerID is not set on remove requests, but rego checks that we can + // only umount previously mounted targets anyway + if err = securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { + return errors.Wrap(err, "overlay removal denied by policy") + } - return storage.UnmountPath(ctx, cl.ContainerRootPath, true) - default: - return newInvalidRequestTypeError(rt) - } + // Note: storage.UnmountPath is a no-op if the path does not exist. + err = storage.UnmountPath(ctx, cl.ContainerRootPath, true) + if err != nil { + h.setUVMInconsistent( + fmt.Errorf("unmounting overlay at %s failed: %w", cl.ContainerRootPath, err), + ) + return err + } + return nil + default: + return newInvalidRequestTypeError(rt) + } + }) } func modifyNetwork(ctx context.Context, rt guestrequest.RequestType, na *guestresource.LCOWNetworkAdapter) (err error) { diff --git a/internal/guest/storage/mount.go b/internal/guest/storage/mount.go index 15664dd693..9b7b946b88 100644 --- a/internal/guest/storage/mount.go +++ b/internal/guest/storage/mount.go @@ -127,6 +127,7 @@ func UnmountPath(ctx context.Context, target string, removeTarget bool) (err err if _, err := osStat(target); err != nil { if os.IsNotExist(err) { + log.G(ctx).WithField("target", target).Warnf("UnmountPath called for non-existent path") return nil } return errors.Wrapf(err, "failed to determine if path '%s' exists", target) diff --git a/internal/guest/storage/overlay/overlay.go b/internal/guest/storage/overlay/overlay.go index aa4877508f..84bf8fa529 100644 --- a/internal/guest/storage/overlay/overlay.go +++ b/internal/guest/storage/overlay/overlay.go @@ -56,8 +56,7 @@ func processErrNoSpace(ctx context.Context, path string, err error) { }).WithError(err).Warn("got ENOSPC, gathering diagnostics") } -// MountLayer first enforces the security policy for the container's layer paths -// and then calls Mount to mount the layer paths as an overlayfs. +// MountLayer calls Mount to mount the layer paths as an overlayfs. func MountLayer( ctx context.Context, layerPaths []string, diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index ec62636590..83c586c3eb 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -121,8 +121,9 @@ type Config struct { // Mount creates a mount from the SCSI device on `controller` index `lun` to // `target` // -// `target` will be created. On mount failure the created `target` will be -// automatically cleaned up. +// `target` will be created. On mount failure the created `target`, as well as +// any associated dm-crypt or dm-verify devices will be automatically cleaned +// up. // // If the config has `encrypted` is set to true, the SCSI device will be // encrypted using dm-crypt. @@ -200,7 +201,8 @@ func Mount( var deviceFS string if config.Encrypted { cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition) - encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName) + var encryptedSource string + encryptedSource, err = encryptDevice(spnCtx, source, cryptDeviceName) if err != nil { // todo (maksiman): add better retry logic, similar to how SCSI device mounts are // retried on unix.ENOENT and unix.ENXIO. The retry should probably be on an @@ -211,6 +213,13 @@ func Mount( } } source = encryptedSource + defer func() { + if err != nil { + if err := cleanupCryptDevice(spnCtx, cryptDeviceName); err != nil { + log.G(spnCtx).WithError(err).WithField("cryptDeviceName", cryptDeviceName).Debug("failed to cleanup dm-crypt device after mount failure") + } + } + }() } else { // Get the filesystem that is already on the device (if any) and use that // as the mountType unless `Filesystem` was given. diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index ebfcf8e382..94992047bd 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -999,6 +999,12 @@ func Test_Mount_EncryptDevice_Mkfs_Error(t *testing.T) { } return expectedDevicePath, nil } + cleanupCryptDevice = func(_ context.Context, dmCryptName string) error { + if dmCryptName != expectedCryptTarget { + t.Fatalf("expected cleanupCryptDevice name %q got %q", expectedCryptTarget, dmCryptName) + } + return nil + } osStat = osStatNoop xfsFormat = func(arg string) error { diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter.go b/internal/regopolicyinterpreter/regopolicyinterpreter.go index 66f62c5114..6e316f9b41 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter.go @@ -63,6 +63,9 @@ type RegoModule struct { type regoMetadata map[string]map[string]interface{} +const metadataRootKey = "metadata" +const metadataOperationsKey = "metadata" + type regoMetadataAction string const ( @@ -81,6 +84,11 @@ type regoMetadataOperation struct { // The result from a policy query type RegoQueryResult map[string]interface{} +// An immutable, saved copy of the metadata state. +type SavedMetadata struct { + metadataRoot regoMetadata +} + // deep copy for an object func copyObject(data map[string]interface{}) (map[string]interface{}, error) { objJSON, err := json.Marshal(data) @@ -113,6 +121,24 @@ func copyValue(value interface{}) (interface{}, error) { return valueCopy, nil } +// deep copy for regoMetadata. +// We cannot use copyObject for this due to the fact that map[string]interface{} +// is a concrete type and a map of it cannot be used as a map of interface{}. +func copyRegoMetadata(value regoMetadata) (regoMetadata, error) { + valueJSON, err := json.Marshal(value) + if err != nil { + return nil, err + } + + var valueCopy regoMetadata + err = json.Unmarshal(valueJSON, &valueCopy) + if err != nil { + return nil, err + } + + return valueCopy, nil +} + // NewRegoPolicyInterpreter creates a new RegoPolicyInterpreter, using the code provided. // inputData is the Rego data which should be used as the initial state // of the interpreter. A deep copy is performed on it such that it will @@ -123,8 +149,8 @@ func NewRegoPolicyInterpreter(code string, inputData map[string]interface{}) (*R return nil, fmt.Errorf("unable to copy the input data: %w", err) } - if _, ok := data["metadata"]; !ok { - data["metadata"] = make(regoMetadata) + if _, ok := data[metadataRootKey]; !ok { + data[metadataRootKey] = make(regoMetadata) } policy := &RegoPolicyInterpreter{ @@ -207,7 +233,7 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ r.dataAndModulesMutex.Lock() defer r.dataAndModulesMutex.Unlock() - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return nil, errors.New("illegal interpreter state: invalid metadata object type") } @@ -228,6 +254,32 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ } } +// Saves a copy of the internal policy metadata state. +func (r *RegoPolicyInterpreter) SaveMetadata() (s SavedMetadata, err error) { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) + if !ok { + return SavedMetadata{}, errors.New("illegal interpreter state: invalid metadata object type") + } + s.metadataRoot, err = copyRegoMetadata(metadataRoot) + return s, err +} + +// Restores a previously saved metadata state. +func (r *RegoPolicyInterpreter) RestoreMetadata(m SavedMetadata) error { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + copied, err := copyRegoMetadata(m.metadataRoot) + if err != nil { + return fmt.Errorf("unable to copy metadata: %w", err) + } + r.data[metadataRootKey] = copied + return nil +} + func newRegoMetadataOperation(operation interface{}) (*regoMetadataOperation, error) { var metadataOp regoMetadataOperation @@ -286,7 +338,7 @@ func (r *RegoPolicyInterpreter) UpdateOSType(os string) error { func (r *RegoPolicyInterpreter) updateMetadata(ops []*regoMetadataOperation) error { // dataAndModulesMutex must be held before calling this - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return errors.New("illegal interpreter state: invalid metadata object type") } @@ -431,7 +483,7 @@ func (r *RegoPolicyInterpreter) logMetadata() { return } - contents, err := json.Marshal(r.data["metadata"]) + contents, err := json.Marshal(r.data[metadataRootKey]) if err != nil { r.metadataLogger.Printf("error marshaling metadata: %v\n", err.Error()) } else { @@ -637,7 +689,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) r.logResult(rule, resultSet) ops := []*regoMetadataOperation{} - if rawMetadata, ok := resultSet["metadata"]; ok { + if rawMetadata, ok := resultSet[metadataOperationsKey]; ok { metadata, ok := rawMetadata.([]interface{}) if !ok { return nil, errors.New("error loading metadata array: invalid type") @@ -660,7 +712,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) } for name, value := range resultSet { - if name == "metadata" { + if name == metadataOperationsKey { continue } else { result[name] = value diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go index b7d86609f7..534b87e892 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go @@ -72,6 +72,39 @@ func Test_copyValue(t *testing.T) { } } +func Test_copyRegoMetadata(t *testing.T) { + f := func(orig testRegoMetadata) bool { + copy, err := copyRegoMetadata(regoMetadata(orig)) + if err != nil { + t.Error(err) + return false + } + + if len(orig) != len(copy) { + t.Errorf("original and copy have different number of objects: %d != %d", len(orig), len(copy)) + return false + } + + for name, origObject := range orig { + if copyObject, ok := copy[name]; ok { + if !assertObjectsEqual(origObject, copyObject) { + t.Errorf("original and copy differ on key %s", name) + return false + } + } else { + t.Errorf("copy missing object %s", name) + return false + } + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 30, Rand: testRand}); err != nil { + t.Errorf("Test_copyRegoMetadata: %v", err) + } +} + //go:embed test.rego var testCode string @@ -364,6 +397,107 @@ func Test_Metadata_Remove(t *testing.T) { } } +func Test_Metadata_SaveRestore(t *testing.T) { + rego, err := setupRego() + if err != nil { + t.Fatal(err) + } + + f := func(pairs1before, pairs1after intPairArray, name1 metadataName, pairs2before, pairs2after intPairArray, name2 metadataName) bool { + if name1 == name2 { + t.Fatalf("generated two identical names: %s", name1) + } + + err := appendAll(rego, pairs1before, name1) + if err != nil { + t.Errorf("error appending pairs1before: %v", err) + return false + } + err = appendAll(rego, pairs2before, name2) + if err != nil { + t.Errorf("error appending pairs2before: %v", err) + return false + } + + saved, err := rego.SaveMetadata() + if err != nil { + t.Errorf("unable to save metadata: %v", err) + return false + } + + beforeSum1 := getExpectedGapFromPairs(pairs1before) + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Error(err) + return false + } + + beforeSum2 := getExpectedGapFromPairs(pairs2before) + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Error(err) + return false + } + + // computeGap would have cleared the list, so we restore it. + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = appendAll(rego, pairs1after, name1) + if err != nil { + t.Errorf("error appending pairs1after: %v", err) + return false + } + + err = appendAll(rego, pairs2after, name2) + if err != nil { + t.Errorf("error appending pairs2after: %v", err) + return false + } + + afterSum1 := beforeSum1 + getExpectedGapFromPairs(pairs1after) + err = computeGap(rego, name1, afterSum1) + if err != nil { + t.Error(err) + return false + } + + afterSum2 := beforeSum2 + getExpectedGapFromPairs(pairs2after) + err = computeGap(rego, name2, afterSum2) + if err != nil { + t.Error(err) + return false + } + + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Errorf("computeGap failed for name1 after restore: %v", err) + return false + } + + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Errorf("computeGap failed for name2 after restore: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 100, Rand: testRand}); err != nil { + t.Errorf("Test_Metadata_SaveRestore: %v", err) + } +} + //go:embed module.rego var moduleCode string @@ -508,6 +642,7 @@ type testValue struct { } type testArray []interface{} type testObject map[string]interface{} +type testRegoMetadata regoMetadata type testValueType int @@ -580,6 +715,16 @@ func (testObject) Generate(r *rand.Rand, _ int) reflect.Value { return reflect.ValueOf(value) } +func (testRegoMetadata) Generate(r *rand.Rand, _ int) reflect.Value { + numObjects := r.Intn(maxNumberOfFields) + metadata := make(testRegoMetadata) + for i := 0; i < numObjects; i++ { + name := uniqueString(r) + metadata[name] = generateObject(r, 0) + } + return reflect.ValueOf(metadata) +} + func getResult(r *RegoPolicyInterpreter, p intPair, rule string) (RegoQueryResult, error) { input := map[string]interface{}{"a": p.a, "b": p.b} result, err := r.Query("data.test."+rule, input) @@ -640,6 +785,27 @@ func appendLists(r *RegoPolicyInterpreter, p intPair, name metadataName) error { return nil } +func appendAll(r *RegoPolicyInterpreter, pairs intPairArray, name metadataName) error { + for _, pair := range pairs { + if err := appendLists(r, pair, name); err != nil { + return fmt.Errorf("error appending pair %v: %w", pair, err) + } + } + return nil +} + +func getExpectedGapFromPairs(pairs intPairArray) int { + expected := 0 + for _, pair := range pairs { + if pair.a >= pair.b { + expected += pair.a - pair.b + } else { + expected += pair.b - pair.a + } + } + return expected +} + func computeGap(r *RegoPolicyInterpreter, name metadataName, expected int) error { input := map[string]interface{}{"name": string(name)} result, err := r.Query("data.test.compute_gap", input) diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index 2967b5a6b7..6afe7de6cc 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -347,18 +347,25 @@ type regoPlan9MountTestConfig struct { } func mountImageForContainer(policy *regoEnforcer, container *securityPolicyContainer) (string, error) { - ctx := context.Background() containerID := testDataGenerator.uniqueContainerID() + if err := mountImageForContainerWithID(policy, container, containerID); err != nil { + return "", err + } + return containerID, nil +} + +func mountImageForContainerWithID(policy *regoEnforcer, container *securityPolicyContainer, containerID string) error { + ctx := context.Background() layerPaths, err := testDataGenerator.createValidOverlayForContainer(policy, container) if err != nil { - return "", fmt.Errorf("error creating valid overlay: %w", err) + return fmt.Errorf("error creating valid overlay: %w", err) } scratchDisk := getScratchDiskMountTarget(containerID) err = policy.EnforceRWDeviceMountPolicy(ctx, scratchDisk, true, true, "xfs") if err != nil { - return "", fmt.Errorf("error mounting scratch disk: %w", err) + return fmt.Errorf("error mounting scratch disk: %w", err) } overlayTarget := getOverlayMountTarget(containerID) @@ -367,10 +374,10 @@ func mountImageForContainer(policy *regoEnforcer, container *securityPolicyConta err = policy.EnforceOverlayMountPolicy( ctx, containerID, copyStrings(layerPaths), overlayTarget) if err != nil { - return "", fmt.Errorf("error mounting filesystem: %w", err) + return fmt.Errorf("error mounting filesystem: %w", err) } - return containerID, nil + return nil } func buildMountSpecFromMountArray(mounts []mountInternal, sandboxID string, r *rand.Rand) *oci.Spec { @@ -1404,6 +1411,10 @@ func setupRegoCreateContainerTest(gc *generatedConstraints, testContainer *secur return nil, err } + return createTestContainerSpec(gc, containerID, testContainer, privilegedError, policy, defaultMounts, privilegedMounts) +} + +func createTestContainerSpec(gc *generatedConstraints, containerID string, testContainer *securityPolicyContainer, privilegedError bool, policy *regoEnforcer, defaultMounts, privilegedMounts []mountInternal) (*regoContainerTestConfig, error) { envList := buildEnvironmentVariablesFromEnvRules(testContainer.EnvRules, testRand) sandboxID := testDataGenerator.uniqueSandboxID() diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 8dd409fccf..236e64d4bd 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -6,6 +6,7 @@ package securitypolicy import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "os" @@ -963,6 +964,113 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } } +func Test_Rego_EnforceOverlayMountPolicy_MountFail(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + securityPolicy := gc.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + tc := selectContainerFromContainerList(gc.containers, testRand) + tid := testDataGenerator.uniqueContainerID() + scratchTarget := getScratchDiskMountTarget(tid) + + errSimulatedFailure := errors.New("simulated failure") + + err = policy.WithMetadataRollback(func() error { + return policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchTarget, true, true, "xfs") + }) + if err != nil { + t.Errorf("failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + + layerToErr := testRand.Intn(len(tc.Layers)) + errLayerPathIndex := len(tc.Layers) - layerToErr - 1 + layerPaths := make([]string, len(tc.Layers)) + for i, layerHash := range tc.Layers { + target := testDataGenerator.uniqueLayerMountTarget() + layerPaths[len(tc.Layers)-i-1] = target + var policyErr error + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if policyErr != nil { + return policyErr + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback. + return errSimulatedFailure + } + return nil + }) + if policyErr != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", policyErr) + return false + } + } + + overlayTarget := getOverlayMountTarget(tid) + var policyErr error + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if commitOnEnforcementFailure { + return nil + } + return policyErr + }) + if err != nil && commitOnEnforcementFailure { + t.Errorf("Expected WithMetadataRollback to not return an error, but got: %v", err) + return false + } + if !assertDecisionJSONContains(t, policyErr, append(slices.Clone(layerPaths), "no matching containers for overlay")...) { + return false + } + + layerPathsWithoutErr := make([]string, 0) + for i, layerPath := range layerPaths { + if i != errLayerPathIndex { + layerPathsWithoutErr = append(layerPathsWithoutErr, layerPath) + } + } + + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPathsWithoutErr, overlayTarget) + if commitOnEnforcementFailure { + return nil + } + return policyErr + }) + if err != nil && commitOnEnforcementFailure { + t.Errorf("Expected WithMetadataRollback to not return an error, but got: %v", err) + return false + } + if !assertDecisionJSONContains(t, policyErr, append(slices.Clone(layerPathsWithoutErr), "no matching containers for overlay")...) { + return false + } + + retryTarget := layerPaths[errLayerPathIndex] + err = policy.WithMetadataRollback(func() error { + return policy.EnforceDeviceMountPolicy(gc.ctx, retryTarget, tc.Layers[layerToErr]) + }) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy again after one previous reverted failure: %v", err) + return false + } + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if err != nil { + t.Errorf("failed to EnforceOverlayMountPolicy after one previous reverted failure: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceOverlayMountPolicy_MountFail: %v", err) + } +} + func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoOverlayTest(p, true) @@ -6103,6 +6211,227 @@ func Test_Rego_Enforce_CreateContainer_RequiredEnvMissingHasErrorMessage(t *test } } +func Test_Rego_EnforceCreateContainer_RejectRevertedOverlayMount(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + layers, err := testDataGenerator.createValidOverlayForContainer(policy, container) + if err != nil { + t.Errorf("Failed to createValidOverlayForContainer: %v", err) + return false + } + + errSimulatedFailure := errors.New("simulated failure") + + scratchMountTarget := getScratchDiskMountTarget(containerID) + err = policy.WithMetadataRollback(func() error { + return policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + }) + if err != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + + overlayTarget := getOverlayMountTarget(containerID) + var policyErr error + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layers, overlayTarget) + if policyErr != nil { + return policyErr + } + // Simulate a failure by rolling back the overlay mount. + return errSimulatedFailure + }) + if policyErr != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", policyErr) + return false + } + + err = policy.WithMetadataRollback(func() error { + _, _, _, policyErr = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if commitOnEnforcementFailure { + return nil + } + return policyErr + }) + if policyErr == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + return false + } + if err != nil && commitOnEnforcementFailure { + t.Errorf("Expected WithMetadataRollback to not return an error, but got: %v", err) + return false + } + + // "Retry" overlay mount + err = policy.WithMetadataRollback(func() error { + return policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, layers, overlayTarget) + }) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + + err = policy.WithMetadataRollback(func() error { + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + return err + }) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying overlay mount: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + +func Test_Rego_EnforceCreateContainer_RetryEverything(t *testing.T) { + f := func(gc *generatedConstraints, + newContainerID, failScratchMount, testDenyInvalidContainerCreation bool, + ) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + errSimulatedFailure := errors.New("simulated failure") + + scratchMountTarget := getScratchDiskMountTarget(containerID) + var policyErr error + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if policyErr != nil { + return policyErr + } + if failScratchMount { + return errSimulatedFailure + } + return nil + }) + if policyErr != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", policyErr) + return false + } + + succeedLayerPaths := make([]string, 0) + + if !failScratchMount { + // Simulate one of the layers failing to mount, after which the outside + // gives up on this container and starts over. + layerToErr := testRand.Intn(len(container.Layers)) + for i, layerHash := range container.Layers { + target := testDataGenerator.uniqueLayerMountTarget() + err = policy.WithMetadataRollback(func() error { + policyErr = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if policyErr != nil { + return policyErr + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback. + return errSimulatedFailure + } + return nil + }) + if policyErr != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", policyErr) + return false + } + if i == layerToErr { + // The simulated mount failure was rolled back, so the outside + // gives up on this container and starts over. + break + } + succeedLayerPaths = append(succeedLayerPaths, target) + } + + for _, layerPath := range succeedLayerPaths { + err = policy.WithMetadataRollback(func() error { + return policy.EnforceDeviceUnmountPolicy(gc.ctx, layerPath) + }) + if err != nil { + t.Errorf("Failed to EnforceDeviceUnmountPolicy: %v", err) + return false + } + } + + err = policy.WithMetadataRollback(func() error { + return policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchMountTarget) + }) + if err != nil { + t.Errorf("Failed to EnforceRWDeviceUnmountPolicy: %v", err) + return false + } + } + + if testDenyInvalidContainerCreation { + err = policy.WithMetadataRollback(func() error { + _, _, _, policyErr = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + return policyErr + }) + if policyErr == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + return false + } + } + + if newContainerID { + tc.containerID = testDataGenerator.uniqueContainerID() + } + + err = mountImageForContainerWithID(policy, container, tc.containerID) + if err != nil { + t.Errorf("Failed to mount image for container after reverting and retrying: %v", err) + return false + } + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + func Test_Rego_ExecInContainerPolicy_RequiredEnvMissingHasErrorMessage(t *testing.T) { constraints := generateConstraints(testRand, 1) container := selectContainerFromContainerList(constraints.containers, testRand) diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 2a4edefce1..3f94bae0f5 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -128,6 +128,7 @@ type SecurityPolicyEnforcer interface { GetUserInfo(spec *oci.Process, rootPath string) (IDName, []IDName, string, error) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string, mountedCim []string) (err error) EnforceRegistryChangesPolicy(ctx context.Context, containerID string, registryValues interface{}) error + WithMetadataRollback(fn func() error) error } //nolint:unused @@ -324,6 +325,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context.C return nil } +func (OpenDoorSecurityPolicyEnforcer) WithMetadataRollback(fn func() error) error { + return fn() +} + type ClosedDoorSecurityPolicyEnforcer struct{} var _ SecurityPolicyEnforcer = (*ClosedDoorSecurityPolicyEnforcer)(nil) @@ -452,3 +457,7 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Co func (ClosedDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context.Context, containerID string, registryValues interface{}) error { return errors.New("registry changes are denied by policy") } + +func (ClosedDoorSecurityPolicyEnforcer) WithMetadataRollback(fn func() error) error { + return fn() +} diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 96c5613dd6..6cd911b5ec 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -12,6 +12,7 @@ import ( "regexp" "slices" "strings" + "sync" "syscall" "github.com/Microsoft/hcsshim/internal/guestpath" @@ -58,6 +59,8 @@ type regoEnforcer struct { maxErrorMessageLength int // OS type osType string + // Mutex to ensure only one transaction is active + transactionLock sync.Mutex } var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil) @@ -1191,3 +1194,29 @@ func (policy *regoEnforcer) EnforceRegistryChangesPolicy(ctx context.Context, co func (policy *regoEnforcer) GetUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) { return GetAllUserInfo(process, rootPath) } + +// WithMetadataRollback snapshots metadata, runs fn and rolls metadata back if fn +// returns an error. Nested or concurrent transactions are rejected. Returns +// the error from fn if it fails. +func (policy *regoEnforcer) WithMetadataRollback(fn func() error) error { + if !policy.transactionLock.TryLock() { + return errors.New("nested or concurrent policy transactions are not supported") + } + defer policy.transactionLock.Unlock() + + saved, err := policy.rego.SaveMetadata() + if err != nil { + return errors.Wrap(err, "failed to snapshot policy metadata") + } + + err = fn() + if err != nil { + if restoreErr := policy.rego.RestoreMetadata(saved); restoreErr != nil { + panic(fmt.Sprintf("failed to rollback policy metadata: %v (caused by error: %v)", restoreErr, err)) + } + log.G(context.Background()).WithError(err).Warn("rolled back policy metadata due to error") + return err + } + + return nil +}