#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# bolt integration test suite
#
# Copyright © 2017 Red Hat, Inc
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
# Authors:
#       Christian J. Kellner <christian@kellner.me>
#
# pylint: disable=too-many-lines, too-many-statements, too-many-public-methods
# pylint: disable=too-many-instance-attributes


import binascii
import errno
import gc
import os
import shutil
import socket
import sys
import subprocess
import unittest
import uuid
import tempfile
import time

from collections import namedtuple
from contextlib import contextmanager
from functools import reduce
from itertools import chain

try:
    import gi
    from gi.repository import GLib
    from gi.repository import Gio
    gi.require_version('UMockdev', '1.0')
    from gi.repository import UMockdev

    import dbus
    import dbusmock

    import configparser
except ImportError as e:
    sys.stderr.write('Skipping integration test due to missing dependencies: %s\n' % str(e))
    sys.exit(1)


# Well-known bus name and root object path of the bolt daemon.
DBUS_NAME = 'org.freedesktop.bolt'
DBUS_PATH = '/org/freedesktop/bolt'

# All bolt D-Bus interface names share this prefix.
DBUS_IFACE_PREFIX = 'org.freedesktop.bolt1.'
DBUS_IFACE_MANAGER = ''.join((DBUS_IFACE_PREFIX, 'Manager'))
DBUS_IFACE_DEVICE = ''.join((DBUS_IFACE_PREFIX, 'Device'))
DBUS_IFACE_DOMAIN = ''.join((DBUS_IFACE_PREFIX, 'Domain'))

# D-Bus activation file of the installed daemon; its Exec= line is used to
# locate the system boltd binary.
SERVICE_FILE = '/usr/share/dbus-1/system-services/org.freedesktop.bolt.service'


def get_timeout(topic='default'):
    """Return the timeout, in seconds, to use for *topic*.

    Known topics are 'default' and 'daemon_start'. When the VALGRIND
    environment variable is set (to any value) considerably longer
    timeouts are returned.

    Raises ValueError for an unknown topic.
    """
    if os.getenv('VALGRIND') is None:
        table = {'default': 3, 'daemon_start': 5}
    else:
        table = {'default': 20, 'daemon_start': 60}

    seconds = table.get(topic)
    if seconds is None:
        raise ValueError('invalid topic')
    return seconds


def can_override_dac():
    """Check if we have CAP_DAC_OVERRIDE

    An empty scratch file is created, all of its permission bits are
    removed, and os.access() is asked whether it is still writable. The
    original mode is restored before the scratch directory is removed.
    """
    with tempfile.TemporaryDirectory() as scratch:
        probe = os.path.join(scratch, "access-check")
        with open(probe, "w") as handle:
            handle.write("")

        saved_mode = os.stat(probe).st_mode & 0o7777
        os.chmod(probe, 0)
        writable = os.access(probe, os.W_OK)
        os.chmod(probe, saved_mode)

    return writable


class Signal:
    """A minimal in-process signal with an optional GObject bridge.

    Callbacks are stored in a set, so connecting the same callable twice has
    no additional effect. A signal can be *bridged* to an object offering
    ``connect(name, callback) -> id`` and ``disconnect(id)`` (e.g. a GObject).
    The underlying handler is connected lazily when the first callback is
    added and disconnected again once the last callback is removed.
    """

    def __init__(self, name):
        self.name = name
        self.callbacks = set()
        self._bridge = None

    def connect(self, callback):
        """Add *callback*; hook up the bridge when it is the first one."""
        if callback in self.callbacks:
            # Already connected: must not build the bridge a second time,
            # which would connect a duplicate handler on the bridged object.
            return
        self.callbacks.add(callback)
        if len(self.callbacks) == 1:
            self._bridge_build()

    def disconnect(self, callback):
        """Remove *callback* (KeyError if it is not connected).

        The bridged handler is disconnected when no callbacks are left.
        """
        self.callbacks.remove(callback)
        if not self.callbacks:
            self._bridge_destroy()

    def disconnect_all(self):
        """Remove every callback and disconnect the bridged handler, if any."""
        self.callbacks = set()
        self._bridge_destroy()

    def emit(self, *args, **kwargs):
        """Call every callback; return True if any callback returned a truthy value.

        All callbacks are called (no short-circuit). A snapshot of the set
        is iterated so callbacks may connect/disconnect while being invoked.
        """
        res = [cb(*args, **kwargs) for cb in list(self.callbacks)]
        return any(res)

    def bridge(self, obj, name, callback):
        """Forward signal *name* of *obj* to this signal.

        *callback*, if not None, is a filter called as
        ``filter(args, kwargs) -> (accept, args, kwargs)``; when *accept* is
        falsy the emission is dropped, otherwise the returned args/kwargs are
        emitted. Raises ValueError if the signal is already bridged.
        """
        if self._bridge is not None:
            raise ValueError('already bridged')
        self._bridge = {'object': obj,
                        'name': name}
        if callback is not None:
            self._bridge['filter'] = callback
        # Callbacks may have been connected before the bridge was set up.
        if self.callbacks:
            self._bridge_build()

    def bridge_destroy(self):
        """Disconnect the bridged handler (if connected) and forget the bridge."""
        self._bridge_destroy()
        self._bridge = None

    # Backwards-compatible alias for the historical misspelling.
    birdge_destroy = bridge_destroy

    def _bridge_build(self):
        """Connect to the bridged object unless that is already done."""
        b = self._bridge
        if b is None or 'signal_id' in b:
            return
        b['signal_id'] = b['object'].connect(b['name'], self._bridge_signal)

    def _bridge_destroy(self):
        """Disconnect from the bridged object if a handler is connected."""
        b = self._bridge
        if b is None or 'signal_id' not in b:
            return
        b['object'].disconnect(b['signal_id'])
        del b['signal_id']

    # Backwards-compatible alias for the historical misspelling.
    _bridge_destory = _bridge_destroy

    def _bridge_signal(self, *args, **kwargs):
        """Handler connected on the bridged object; applies the filter and emits."""
        b = self._bridge
        if b is None:
            return False
        if 'filter' in b:
            res, args, kwargs = b['filter'](args, kwargs)
            if not res:
                return False
        return self.emit(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.emit(*args, **kwargs)

    def __iadd__(self, callback):
        self.connect(callback)
        return self

    def __isub__(self, callback):
        self.disconnect(callback)
        return self

    @staticmethod
    def enable(klass):
        """Class decorator: install a Signal property per name in ``klass.signals``.

        Each instance lazily gets its own Signal objects, kept in a dict
        stored under the instance attribute '__signals'. The property setter
        ignores the assigned value, which makes ``obj.sig += cb`` and
        ``obj.sig -= cb`` work in place. The signal names of the direct base
        classes are merged into ``klass.signals``.
        """
        lst = getattr(klass, 'signals', [])
        methods = [m for m in dir(klass) if not m.startswith('__')]

        def install(sig_name):
            if sig_name is None:
                return sig_name
            if sig_name in methods:
                print('WARNING: signal "%s" will overwrite method' % sig_name, file=sys.stderr)

            def get_signals(self):
                signals = getattr(self, '__signals', None)
                if signals is None:
                    signals = {}
                    setattr(self, '__signals', signals)
                return signals

            def get_signal(self):
                signals = get_signals(self)
                if sig_name not in signals:
                    signals[sig_name] = Signal(sig_name)
                return signals[sig_name]

            def getter(self):
                return get_signal(self)

            def setter(self, _value):
                return get_signal(self)

            p = property(getter, setter)
            setattr(klass, sig_name, p)
            return sig_name

        bases = klass.__bases__
        ps = {s for b in bases for s in getattr(b, 'signals', [])}
        klass.signals = list(ps.union({install(l) for l in lst}))
        return klass


class Recorder:
    """Record property changes and D-Bus signals of a proxy wrapper.

    *target* must expose ``g_properties_changed`` (called with a dict of
    changed properties) and ``g_signal`` Signals, as ProxyWrapper does.
    Every observation is stored as an ``Event`` in ``self.events`` and
    re-emitted via ``self.event``.
    """
    Event = namedtuple('Event', ['what', 'name', 'details', 'time'])

    def __init__(self, target):
        self.recording = True
        self.event = Signal('event')
        self.events = []
        self.target = target
        self.target.g_properties_changed += self._on_props_changed
        self.target.g_signal += self._on_signal

    def close(self):
        """Stop recording; return the recorded events (None if already closed)."""
        if not self.recording:
            return None
        self.target.g_properties_changed -= self._on_props_changed
        self.target.g_signal -= self._on_signal
        self.recording = False
        return self.events

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _on_props_changed(self, props):
        # one 'property' event per changed property, all with the same time
        now = time.time()
        for k, v in props.items():
            event = self.Event('property', k, v, now)
            self._add_event(event)

    def _on_signal(self, _proxy, _sender, signal, params):
        now = time.time()
        event = self.Event('signal', signal, params, now)
        self._add_event(event)

    def _add_event(self, event):
        self.events.append(event)
        self.event.emit(event)

    @staticmethod
    def event_match(event, target):
        """True if *event* matches *target*; a target with details None matches any details."""
        if event.what != target.what or event.name != target.name:
            return False
        if target.details is None:
            return True
        return event.details == target.details

    @staticmethod
    def events_list_has_event(events, target):
        """Return the list of events in *events* matching *target*."""
        return list(filter(lambda x: Recorder.event_match(x, target), events))

    @staticmethod
    def events_list_contains(events, what, name, details=None):
        """Return the events in *events* matching (*what*, *name*, *details*)."""
        target = Recorder.Event(what, name, details, None)
        return list(filter(lambda x: Recorder.event_match(x, target), events))

    def events_filter(self, event):
        return self.events_list_has_event(self.events, event)

    def have_event(self, event):
        return len(self.events_filter(event)) > 0

    def wait_for_events(self, lst, timeout=None):
        """Run a GLib main loop until every event in *lst* has been recorded.

        Events recorded before the call count as seen. *lst* itself is not
        modified. *timeout* is in seconds (default: get_timeout()).
        Returns True if all events were seen, False on timeout.
        """
        # Only wait for what has not been recorded yet.
        pending = [e for e in lst if not self.have_event(e)]
        if not pending:
            return True

        loop = GLib.MainLoop()
        timed_out = False

        def got_event(event):
            for idx, current in enumerate(pending):
                if self.event_match(event, current):
                    del pending[idx]
                    break
            if not pending:
                loop.quit()

        def got_timeout():
            nonlocal timed_out
            timed_out = True
            print('WARNING: timeout reached! Want: %s, Have: %s' %
                  (str(pending), str(self.events)), file=sys.stderr)
            loop.quit()
            return False  # one-shot: GLib removes the source

        self.event += got_event
        timeout = timeout or get_timeout()
        source_id = GLib.timeout_add(int(timeout * 1000), got_timeout)
        try:
            loop.run()
        finally:
            self.event -= got_event
            # If the events arrived in time the timeout source is still
            # attached; remove it so it does not fire during a later loop.
            if not timed_out:
                GLib.source_remove(source_id)
        return len(pending) == 0

    def wait_for_event(self, what, name, details=None):
        event = self.Event(what, name, details, None)
        return self.wait_for_events([event])

    def wait_for_props(self, **kwargs):
        """Wait for property events name=details for every keyword; assert they arrive."""
        events = [self.Event('property', name, details, None) for name, details in kwargs.items()]
        res = self.wait_for_events(events)
        assert res


@Signal.enable
class ProxyWrapper:
    """Convenience wrapper around a Gio.DBusProxy for one bolt object.

    Unknown attribute lookups are forwarded to the proxy. The name is first
    converted: names containing '_' have each '_'-separated part
    title-cased and joined (``auth_mode`` -> ``AuthMode``); other names get
    their first letter upper-cased. A matching cached D-Bus property is
    returned unpacked; anything else is looked up on the Gio.DBusProxy.

    The proxy's 'g-properties-changed' and 'g-signal' GObject signals are
    exposed as the ``g_properties_changed`` Signal (called with a dict of the
    changed properties) and the ``g_signal`` Signal.
    """
    signals = ['g_properties_changed', 'g_signal']

    def __init__(self, bus, iname, path):
        # DO_NOT_AUTO_START: the bus is not asked to activate the service.
        self._proxy = Gio.DBusProxy.new_sync(bus,
                                             Gio.DBusProxyFlags.DO_NOT_AUTO_START,
                                             None,
                                             DBUS_NAME,
                                             path,
                                             iname,
                                             None)

        def props_changed(args, _kwargs):
            # args[1] is the variant holding the changed properties
            return True, [args[1].unpack()], {}

        self.g_properties_changed.bridge(self._proxy,
                                         'g-properties-changed',
                                         props_changed)

        self.g_signal.bridge(self._proxy, 'g-signal', None)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Private names are never
        # forwarded, which also prevents recursion on self._proxy before
        # __init__ has set it.
        if not name or name.startswith('_'):
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, name))
        if '_' in name:
            c = name.split('_')
            name = "".join(x.title() for x in c)
        else:
            name = name[0].upper() + name[1:]
        # get_cached_property_names() returns NULL/None when the proxy has
        # no cached properties at all.
        cached = self._proxy.get_cached_property_names() or ()
        if name in cached:
            value = self._proxy.get_cached_property(name)
            if value is not None:
                return value.unpack()
            return value

        return getattr(self._proxy, name)

    def record(self):
        """Return a Recorder that records this object's changes and signals."""
        return Recorder(self)

    @property
    def object_path(self):
        return self._proxy.get_object_path()


class BoltDevice(ProxyWrapper):
    """Client-side proxy for a bolt1.Device object.

    The properties below translate the daemon's string-valued D-Bus
    properties into the integer constants defined on this class.
    """
    # status
    UNKNOWN = -1
    DISCONNECTED = 0
    CONNECTED = 1
    CONNECTING = 2
    AUTHORIZING = 3
    AUTH_ERROR = 4
    AUTHORIZED = 5

    # key state
    KEY_MISSING = 0
    KEY_HAVE = 1
    KEY_NEW = 2

    # authorization flags (bit mask)
    NOPCIE = 1 << 0
    SECURE = 1 << 1
    NOKEY = 1 << 2
    BOOT = 1 << 3

    # device type
    HOST = 'host'
    PERIPHERAL = 'peripheral'

    def __init__(self, bus, path):
        super(BoltDevice, self).__init__(bus,
                                         DBUS_IFACE_DEVICE,
                                         path)

    def _set_property(self, name, value):
        """Set D-Bus property *name* via org.freedesktop.DBus.Properties.Set.

        Plain str values are wrapped into an 's' variant; other values must
        already be GLib.Variants. Returns True if the call returned a result.
        """
        if isinstance(value, str):
            value = GLib.Variant("s", value)
        target = GLib.Variant("(ssv)", (DBUS_IFACE_DEVICE, name, value))
        res = self._proxy.call_sync('org.freedesktop.DBus.Properties.Set',
                                    target,
                                    0,
                                    -1,
                                    None)
        return res is not None

    @property
    def is_connected(self):
        return self.status > self.DISCONNECTED

    @property
    def is_authorized(self):
        return self.status >= self.AUTHORIZED

    def authorize(self, flags=""):
        self.Authorize('(s)', flags)
        return True

    @property
    def status(self):
        """The Status property as one of the status constants (UNKNOWN if unrecognised)."""
        res = getattr(self, 'Status')
        mapping = {'unknown': self.UNKNOWN,
                   'disconnected': self.DISCONNECTED,
                   'connecting': self.CONNECTING,
                   'connected': self.CONNECTED,
                   'authorizing': self.AUTHORIZING,
                   'auth-error': self.AUTH_ERROR,
                   'authorized': self.AUTHORIZED}

        return mapping.get(res, self.UNKNOWN)

    @property
    def authflags(self):
        """The '|'-separated AuthFlags property as a bit mask of the flag constants."""
        res = getattr(self, 'AuthFlags')
        mapping = {'none': 0,
                   'nopcie': self.NOPCIE,
                   'secure': self.SECURE,
                   'nokey': self.NOKEY,
                   'boot': self.BOOT}

        keys = [x.strip() for x in res.split('|')]
        return reduce(lambda r, x: r | mapping.get(x, 0), keys, 0)

    @property
    def device_type(self):
        return getattr(self, 'type')

    @property
    def key(self):
        """The Key property as a KEY_* constant (KEY_MISSING if unrecognised)."""
        res = getattr(self, 'Key')
        mapping = {'missing': self.KEY_MISSING,
                   'have': self.KEY_HAVE,
                   'new': self.KEY_NEW}

        return mapping.get(res, self.KEY_MISSING)

    @property
    def label(self):
        res = getattr(self, 'Label')
        if not res:
            return None
        return res

    @label.setter
    def label(self, value):
        return self._set_property('Label', value)

    @property
    def policy(self):
        res = getattr(self, 'Policy')
        return res if res else None

    @policy.setter
    def policy(self, value):
        return self._set_property('Policy', value)

    @property
    def linkspeed(self):
        res = getattr(self, 'LinkSpeed')
        if not res:
            return None
        return res

    @staticmethod
    def gen_object_path(object_id):
        """Return the D-Bus object path of the device with unique id *object_id*."""
        base = os.path.join(DBUS_PATH, "devices")
        return BoltClient.gen_object_path(base, object_id)


class BoltDomain(ProxyWrapper):
    """Client-side proxy for a bolt1.Domain object."""

    SECURITY_NONE = 'none'
    SECURITY_USER = 'user'
    SECURITY_SECURE = 'secure'
    SECURITY_DPONLY = 'dponly'
    SECURITY_USBONLY = 'usbonly'
    SECURITY_NOPCIE = 'nopcie'

    @classmethod
    def security_levels(cls):
        """Values of all SECURITY_* attributes defined directly on *cls*, in definition order."""
        return [value
                for attr, value in vars(cls).items()
                if attr.startswith('SECURITY_')]

    def __init__(self, bus, path):
        super().__init__(bus, DBUS_IFACE_DOMAIN, path)

    @property
    def bootacl(self):
        """The BootACL property, or None when it is empty or unavailable."""
        acl = getattr(self, 'BootACL')
        return acl if acl else None

    @bootacl.setter
    def bootacl(self, value):
        # value is packed as an 'as' (array of strings) variant
        packed = GLib.Variant("as", value)
        request = GLib.Variant("(ssv)", (DBUS_IFACE_DOMAIN, "BootACL", packed))
        self._proxy.call_sync('org.freedesktop.DBus.Properties.Set',
                              request, 0, -1, None)


@Signal.enable
class BoltClient(ProxyWrapper):
    """Client-side proxy for the bolt1.Manager object at DBUS_PATH.

    The manager's DeviceAdded and DeviceRemoved D-Bus signals are re-emitted
    as ``device_added`` (with a BoltDevice for the new object) and
    ``device_removed`` (with the first signal argument).
    """
    signals = ['device_added', 'device_removed']

    POLICY_DEFAULT = 'default'
    POLICY_MANUAL = 'manual'
    POLICY_AUTO = 'auto'
    POLICY_IOMMU = 'iommu'

    def __init__(self, bus):
        super(BoltClient, self).__init__(bus,
                                         DBUS_IFACE_MANAGER,
                                         DBUS_PATH)
        self._proxy.connect('g-signal', self._on_dbus_signal)

    def _on_dbus_signal(self, _proxy, _sender, signal, params):
        bus = self._proxy.get_connection()
        if signal == 'DeviceAdded':
            self.device_added.emit(BoltDevice(bus, params[0]))
            return True
        if signal == 'DeviceRemoved':
            self.device_removed.emit(params[0])
            return True
        return False

    @property
    def auth_mode(self):
        res = getattr(self, 'AuthMode')
        if not res:
            return None
        return res

    @auth_mode.setter
    def auth_mode(self, value):
        # plain strings are wrapped; other values must already be variants
        if isinstance(value, str):
            value = GLib.Variant("s", value)

        self._proxy.call_sync('org.freedesktop.DBus.Properties.Set',
                              GLib.Variant("(ssv)",
                                           (DBUS_IFACE_MANAGER,
                                            "AuthMode",
                                            value)),
                              0,
                              -1,
                              None)

    def list_domains(self):
        """Return a BoltDomain per domain, or None if the call returned nothing."""
        domains = self.ListDomains()
        if domains is None:
            return None
        bus = self._proxy.get_connection()
        return [BoltDomain(bus, d) for d in domains]

    def domain_by_id(self, uid):
        object_path = self.DomainById("(s)", uid)
        if object_path is None:
            return None
        bus = self._proxy.get_connection()
        return BoltDomain(bus, object_path)

    def list_devices(self):
        """Return a BoltDevice per device, or None if the call returned nothing."""
        devices = self.ListDevices()
        if devices is None:
            return None
        bus = self._proxy.get_connection()
        return [BoltDevice(bus, d) for d in devices]

    def device_by_uid(self, uid):
        object_path = self.DeviceByUid("(s)", uid)
        if object_path is None:
            return None
        bus = self._proxy.get_connection()
        return BoltDevice(bus, object_path)

    def list_peripherals(self):
        """Return all devices whose type is PERIPHERAL (None like list_devices())."""
        devs = self.list_devices()
        if devs is None:
            return None
        return [d for d in devs if d.device_type == BoltDevice.PERIPHERAL]

    def enroll(self, uid, policy=POLICY_DEFAULT, flags=""):
        object_path = self.EnrollDevice("(sss)", uid, policy, flags)
        if object_path is None:
            return None
        bus = self._proxy.get_connection()
        return BoltDevice(bus, object_path)

    def forget(self, uid):
        self.ForgetDevice("(s)", uid)
        return True

    @staticmethod
    def gen_object_path(base, object_id):
        """Build an object path from *base* and *object_id* ('-' becomes '_').

        Either part may be empty/None; the result always starts with '/'.
        """
        oid = None
        if object_id:
            oid = object_id.replace('-', '_')
        if base and oid:
            return os.path.join('/', base, oid)
        if base:
            return os.path.join('/', base)
        if oid:
            return os.path.join('/', oid)
        return '/'


# Mock Device Tree
@Signal.enable
class Device:
    """A node of the mock sysfs device tree that is added to a testbed.

    Subclasses describe what is exported via class attributes:
    ``subsystem``, ``udev_attrs`` (sysfs attributes) and ``udev_props``
    (udev properties). The value for each listed name is read from the
    instance attribute with the lower-cased name; entries whose value is
    None are not exported.

    The root of the tree emits ``device_connected`` and
    ``device_disconnected`` for every node that is (dis)connected.
    The testbed (``bed``) is presumably a UMockdev.Testbed — it must
    provide add_device, remove_device and uevent.
    """
    subsystem = "unknown"
    udev_attrs = []
    udev_props = []

    signals = ['device_connected',
               'device_disconnected']

    def __init__(self, name, children):
        self._parent = None
        self.children = [self._adopt(c) for c in children]
        self.udev = None
        self.name = name
        self.syspath = None
        # testbed this device is connected to (None while disconnected)
        self.testbed = None

    def _adopt(self, device):
        device.parent = self
        return device

    def _get_own(self, items):
        """Return a flat [name, value, ...] list for *items*, skipping None values."""
        def get_pair(a):
            v = getattr(self, a.lower())
            return [a, str(v) if v is not None else v]
        x = [get_pair(a) for a in items]
        i = chain.from_iterable(filter(lambda x: x[1] is not None, x))
        return list(i)

    def collect(self, predicate):
        """Return all nodes of this subtree (pre-order) for which *predicate* is true."""
        head = [self] if predicate(self) else []
        # each child's collect() has already applied the predicate
        tail = chain.from_iterable(c.collect(predicate) for c in self.children)
        return head + list(tail)

    def first(self, predicate):
        """Return the first node (pre-order) matching *predicate*, or None."""
        if predicate(self):
            return self
        for c in self.children:
            found = c.first(predicate)
            if found is not None:
                return found
        return None

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, value):
        self._parent = value

    @property
    def root(self):
        return self if self.parent is None else self.parent.root

    def connect_tree(self, bed):
        """Connect this node and, recursively, all of its children."""
        self.connect(bed)
        for c in self.children:
            c.connect_tree(bed)

    def connect(self, bed):
        """Add this node (only) to *bed* below its parent's syspath."""
        print('connecting ' + self.name, file=sys.stderr)
        assert self.syspath is None
        attributes = self._get_own(self.udev_attrs)
        properties = self._get_own(self.udev_props)
        sysparent = self.parent and self.parent.syspath
        self.syspath = bed.add_device(self.subsystem,
                                      self.name,
                                      sysparent,
                                      attributes,
                                      properties)
        # set before notifying, so handlers can already use the testbed
        self.testbed = bed
        self.root.device_connected(self)

    def disconnect(self, bed):
        """Remove the children, then this node, from *bed*."""
        print('disconnecting ' + self.name, file=sys.stderr)
        for c in self.children:
            c.disconnect(bed)
        bed.uevent(self.syspath, "remove")
        bed.remove_device(self.syspath)
        self.root.device_disconnected(self)
        self.syspath = None
        self.testbed = None


class TbNativeHostInterface(Device):
    """Mock PCI device for the Thunderbolt native host interface (NHI).

    It is the parent of exactly one TbDomain and exports the udev property
    DRIVER=thunderbolt.
    """
    subsystem = "pci"
    udev_attrs = ['class', 'device', 'vendor']
    udev_props = ['DRIVER']

    driver = "thunderbolt"

    def __init__(self, domain, pci_id=0x15d2):
        assert domain
        super().__init__("pci0000:00:0d.%d" % domain.index, [domain])
        self.device = '0x%04x' % pci_id
        self.vendor = '0x8086'
        # 'class' is a keyword, so it cannot be assigned with dot syntax
        setattr(self, 'class', '0x088000')


class TbDevice(Device):
    """Mock Thunderbolt device (DEVTYPE "thunderbolt_device").

    ``authorized`` mirrors the sysfs attribute. 0 maps to
    BoltDevice.CONNECTED. 1 and 2 map to BoltDevice.AUTHORIZED, and 2
    additionally sets the SECURE auth flag. ``key`` mirrors the key
    attribute; None means no key attribute is exported at all.
    """
    subsystem = "thunderbolt"
    devtype = "thunderbolt_device"

    udev_attrs = ['authorized',
                  'device',
                  'device_name',
                  'generation',
                  'key',
                  'unique_id',
                  'vendor',
                  'vendor_name',
                  'rx_lanes',
                  'rx_speed',
                  'tx_lanes',
                  'tx_speed']

    udev_props = ['DEVTYPE']

    def __init__(self, name, authorized=0, vendor=None, uid=None, children=None, key='\n', gen=None):
        super(TbDevice, self).__init__(name, children or [])
        self.unique_id = uid or str(uuid.uuid4())
        self.device_name = 'Thunderbolt ' + name
        self.device = self._make_id(self.device_name)
        self.vendor_name = vendor or 'GNOME.org'
        self.vendor = self._make_id(self.vendor_name)
        self.authorized = authorized
        self.key = key
        self.generation = gen
        self.rx_lanes = None
        self.rx_speed = None
        self.tx_lanes = None
        self.tx_speed = None

    def disconnect(self, bed):
        """Disconnect and reset the authorization state (and an existing key)."""
        super().disconnect(bed)
        self.authorized = 0
        if self.key is not None:
            self.key = "\n"

    @staticmethod
    def _make_id(name):
        """Derive a hex id string from *name* via CRC32."""
        return '0x%X' % binascii.crc32(name.encode('utf-8'))

    @property
    def authorized_file(self):
        if self.syspath is None:
            return None
        return os.path.join(self.syspath, 'authorized')

    @property
    def key_file(self):
        if self.syspath is None:
            return None
        return os.path.join(self.syspath, 'key')

    @property
    def bolt_status(self):
        """The BoltDevice status the daemon is expected to report for this device."""
        if self.syspath is None:
            return BoltDevice.DISCONNECTED
        if self.authorized == 0:
            return BoltDevice.CONNECTED
        if self.authorized in [1, 2]:
            return BoltDevice.AUTHORIZED
        return BoltDevice.UNKNOWN

    @property
    def bolt_authflags(self):
        """The BoltDevice auth flags the daemon is expected to report."""
        flags = 0
        if self.syspath is None:
            return 0

        if self.authorized == 2:
            flags |= BoltDevice.SECURE

        if self.domain.security == TbDomain.SECURITY_SECURE:
            if self.key is None:
                flags |= BoltDevice.NOKEY
        elif self.domain.security in [TbDomain.SECURITY_DPONLY,
                                      TbDomain.SECURITY_USBONLY,
                                      TbDomain.SECURITY_NOPCIE]:
            flags |= BoltDevice.NOPCIE

        return flags

    @property
    def domain(self):
        return self.parent.domain

    @staticmethod
    def is_unauthorized(d):
        return isinstance(d, TbDevice) and d.authorized == 0

    @property
    def bus_path(self):
        return BoltDevice.gen_object_path(self.unique_id)

    @property
    def linkspeed(self):
        return self.rx_lanes, self.rx_speed, self.tx_lanes, self.tx_speed

    @linkspeed.setter
    def linkspeed(self, ls):
        # ls: mapping with "rx.lanes", "rx.speed", "tx.lanes", "tx.speed";
        # speeds are formatted with %u, so presumably whole Gb/s values.
        self.rx_lanes = str(ls["rx.lanes"])
        self.rx_speed = "%u.0 Gb/s" % ls["rx.speed"]
        self.tx_lanes = str(ls["tx.lanes"])
        self.tx_speed = "%u.0 Gb/s" % ls["tx.speed"]

        if self.syspath:
            self.testbed.set_attribute(self.syspath, "rx_lanes", self.rx_lanes)
            self.testbed.set_attribute(self.syspath, "rx_speed", self.rx_speed)
            self.testbed.set_attribute(self.syspath, "tx_lanes", self.tx_lanes)
            self.testbed.set_attribute(self.syspath, "tx_speed", self.tx_speed)
            self.testbed.uevent(self.syspath, 'change')

    def reload_auth(self):
        """Re-read 'authorized' and 'key' from sysfs; send a 'change' uevent if they differ."""
        old_authorized = self.authorized
        old_key = self.key
        with open(self.authorized_file, 'r') as f:
            self.authorized = int(f.read())
        # The key attribute is only exported if the device has a key
        # (None-valued attributes are skipped), so it may not exist.
        try:
            with open(self.key_file, 'r') as f:
                self.key = f.read().strip()
        except FileNotFoundError:
            pass
        if self.authorized != old_authorized or self.key != old_key:
            if self.syspath:
                self.testbed.uevent(self.syspath, 'change')

    def authorize(self, level):
        """Write *level* (a string) to the authorized attribute, then reload."""
        with open(self.authorized_file, 'w') as f:
            f.write(level)
        self.reload_auth()

    def writekey(self, key):
        """Write *key* to the key attribute (state is not reloaded)."""
        with open(self.key_file, 'w') as f:
            f.write(key)


class TbHost(TbDevice):
    """The host's own Thunderbolt device, the child of a TbDomain; always authorized."""

    def __init__(self, children, name='Laptop', gen=None):
        super().__init__(name, authorized=1, gen=gen, children=children)

    def connect(self, bed):
        # TbDevice.disconnect() resets 'authorized' to 0; the host must be
        # authorized again whenever it (re)appears.
        self.authorized = 1
        super().connect(bed)


class TbDomain(Device):
    """Mock Thunderbolt domain (DEVTYPE "thunderbolt_domain").

    The domain's single child is its TbHost. The domain is itself adopted
    by a TbNativeHostInterface (``self.nhi``), which is connected before
    and disconnected after the domain.
    """
    subsystem = "thunderbolt"
    devtype = "thunderbolt_domain"

    udev_attrs = ['security', 'boot_acl', 'iommu_dma_protection']
    udev_props = ['DEVTYPE']

    SECURITY_NONE = 'none'
    SECURITY_USER = 'user'
    SECURITY_SECURE = 'secure'
    SECURITY_DPONLY = 'dponly'
    SECURITY_USBONLY = 'usbonly'
    SECURITY_NOPCIE = 'nopcie'

    def __init__(self, security=SECURITY_SECURE, index=0, host=None, acl=None, iommu=None):
        assert host
        assert isinstance(host, TbHost)
        domain_name = 'domain%d' % index
        if host.unique_id is None:
            host.unique_id = '3b7d4bad-4fdf-44ff-8730-ffffdeadbab%d' % index
        super(TbDomain, self).__init__(domain_name, children=[host])

        self.security = security
        if isinstance(acl, int):
            # an integer N becomes N-1 commas, i.e. N empty ACL entries
            self.boot_acl = ',' * (acl - 1)
        else:
            self.boot_acl = acl
        self.iommu_dma_protection = iommu
        self.index = index

        self.nhi = TbNativeHostInterface(self)

    def connect(self, bed):
        self.nhi.connect(bed)
        super().connect(bed)

    def disconnect(self, bed):
        # nhi.disconnect() recurses into this domain again; the syspath
        # check turns that second call into a no-op.
        if not self.syspath:
            return
        super().disconnect(bed)
        self.nhi.disconnect(bed)

    @property
    def unique_id(self):
        host = self.host
        return host and host.unique_id

    @property
    def host(self):
        return self.children[0] if self.children else None

    @property
    def devices(self):
        """All TbDevices of this domain, host included."""
        def is_tb_device(node):
            return isinstance(node, TbDevice)
        return self.collect(is_tb_device)

    @property
    def peripherals(self):
        """All TbDevices of this domain except the host."""
        def is_peripheral(node):
            return isinstance(node, TbDevice) and not isinstance(node, TbHost)
        return self.collect(is_peripheral)

    @property
    def domain(self):
        return self

    @property
    def iommu(self):
        return self.iommu_dma_protection

    @staticmethod
    def checkattr(d, k, v):
        return hasattr(d, k) and getattr(d, k) == v

    def find(self, include_domain, **kwargs):
        """Return the first node whose attributes equal all *kwargs* (or None)."""
        def matches(node):
            if isinstance(node, TbDomain) and not include_domain:
                return False
            return all([self.checkattr(node, attr, want)
                        for attr, want in kwargs.items()])

        return self.first(matches)


class Parameterize:
    # class is used for namespacing

    @staticmethod
    def _make_test(klass, name, func):
        params = getattr(func, '__expand')

        def test_function(self):
            func(self, **params)
        suffix = '_' + '_'.join(params.keys())
        setattr(klass, name + suffix, test_function)

    @staticmethod
    def enable(klass):
        tests = [(m, getattr(klass, m)) for m in dir(klass) if m.startswith('test_')]
        expandable = filter(lambda f: hasattr(f[1], '__expand'), tests)
        for t in expandable:
            Parameterize._make_test(klass, t[0], t[1])
        map(lambda t: Parameterize._make_test(klass, t[0], t[1]), expandable)
        return klass

    @staticmethod
    def make(**kwargs):
        def decorator(func):
            setattr(func, '__expand', kwargs)
            return func
        return decorator


# Systemd integration tests
class SdNotify:
    """Collect notification datagrams sent to a unix socket in *tmpdir*.

    Presumably the daemon under test sends systemd notify messages (such as
    'READY=1') to ``self.path``. Every received message is kept, decoded as
    UTF-8, in ``self.msgs``.
    """
    def __init__(self, tmpdir):
        assert tmpdir is not None
        soflags = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC | socket.SOCK_NONBLOCK
        self.sock = socket.socket(socket.AF_UNIX, soflags)
        self.path = os.path.join(tmpdir, 'bolt_notify_socket')
        self.sock.bind(self.path)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_PASSCRED, 1)
        self.msgs = []

    def read_msg(self):
        """Read one pending datagram (non-blocking); return it or None if none is queued."""
        rflags = socket.MSG_DONTWAIT | socket.MSG_CMSG_CLOEXEC | socket.MSG_TRUNC
        try:
            msg, _ac, _flags, _addr = self.sock.recvmsg(4096, 0, rflags)
        except OSError as e:
            if e.errno != errno.EINTR and e.errno != errno.EAGAIN:
                raise e
            msg = None

        if msg is None:
            return None
        msg = msg.decode('utf-8')
        self.msgs.append(msg)
        return msg

    def fetch_msgs(self):
        """Drain all pending messages; return how many were read."""
        count = 0
        while self.read_msg():
            count += 1
        return count

    def wait_for(self, event, timeout=None):
        """Wait until a received message starts with *event*.

        All messages received so far count, including ones fetched by an
        earlier call. *timeout* is in seconds (default: get_timeout()).
        Returns the list of all matching messages; raises TimeoutError.
        """
        timeout = timeout or get_timeout()
        timeout_count = timeout * 10
        timeout_sleep = 0.1

        while timeout_count > 0:
            self.fetch_msgs()
            matches = [msg for msg in self.msgs if msg.startswith(event)]
            if matches:
                return matches

            time.sleep(timeout_sleep)
            timeout_count -= 1

        # debug print
        for msg in self.msgs:
            print(msg)

        raise TimeoutError("Timeout waiting for '%s'" % event)

    def close(self):
        """Close the socket; safe to call more than once."""
        self.path = None
        if self.sock is None:
            return
        self.sock.shutdown(socket.SHUT_WR)
        self.sock.close()
        self.sock = None


# Test Suite
@Parameterize.enable
class BoltTest(dbusmock.DBusTestCase):
    @staticmethod
    def path_from_service_file(_sf=SERVICE_FILE):
        """Return the command of the first ``Exec=`` line of a D-Bus service
        file, stripped of surrounding whitespace, or None if there is none.

        :param _sf: path of the service file to parse (previously ignored in
            favour of SERVICE_FILE; now honoured, defaulting to it).
        """
        with open(_sf) as f:
            for line in f:
                if line.startswith('Exec='):
                    return line.split('=', 1)[1].strip()
        return None

    @classmethod
    def setUpClass(cls):
        """Locate the binaries under test and set up a private system bus.

        boltd/boltctl come from $BOLT_BUILD_DIR (local build), from
        $JHBUILD_PREFIX when $UNDER_JHBUILD is set, or otherwise from the
        installed system (the D-Bus service file and $PATH). A Gio.TestDBus
        instance is started and exported as the system bus through
        DBUS_SYSTEM_BUS_ADDRESS; DBUS_SESSION_BUS_ADDRESS is unset.
        """
        boltd = None
        boltctl = None
        if 'BOLT_BUILD_DIR' in os.environ:
            print('Testing local build')
            build_dir = os.environ['BOLT_BUILD_DIR']
            boltd = os.path.join(build_dir, 'boltd')
            boltctl = os.path.join(build_dir, 'boltctl')
        elif 'UNDER_JHBUILD' in os.environ:
            print('Testing JHBuild version')
            jhbuild_prefix = os.environ['JHBUILD_PREFIX']
            boltd = os.path.join(jhbuild_prefix, 'libexec', 'boltd')
            boltctl = os.path.join(jhbuild_prefix, 'bin', 'boltctl')
        else:
            print('Testing installed system binaries')
            boltd = cls.path_from_service_file(SERVICE_FILE)
            boltctl = shutil.which('boltctl')

        assert boltd is not None, 'failed to find daemon'
        # shutil.which() returns None if boltctl is not in $PATH; check that
        # first, since os.access(None, ...) and "..." + None raise TypeError
        assert boltctl is not None, 'failed to find boltctl'
        assert os.access(boltctl, os.X_OK), "could not execute @ " + boltctl
        cls.paths = {'daemon': boltd,
                     'boltctl': boltctl}

        cls.test_bus = Gio.TestDBus.new(Gio.TestDBusFlags.NONE)
        cls.test_bus.up()
        os.environ.pop('DBUS_SESSION_BUS_ADDRESS', None)
        os.environ['DBUS_SYSTEM_BUS_ADDRESS'] = cls.test_bus.get_bus_address()
        cls.dbus = Gio.bus_get_sync(Gio.BusType.SYSTEM, None)
        # a well known key that can be used in testing
        cls.key = 'b68bce095a13ac39e9254a88b189a38f240487aa6f78f803390a0cdeceb774d8'

        # monkey patch Gio.IOErrorEnum to have a .quark() method
        # so it behaves like Gio.DBusError
        io_error_quark = GLib.quark_from_static_string('g-io-error-quark')
        setattr(Gio.IOErrorEnum, 'quark', lambda _self: io_error_quark)

    @classmethod
    def tearDownClass(cls):
        """Shut down the private test bus, then run dbusmock's class teardown."""
        bus = cls.test_bus
        bus.down()
        dbusmock.DBusTestCase.tearDownClass()

    def setUp(self):
        """Create a fresh umockdev testbed plus empty state and runtime dirs.

        The state dir gets ``devices``, ``domains`` and ``keys`` subdirs;
        helper handles (client, daemon, polkitd, sdnotify, log) start unset.
        """
        self.testbed = UMockdev.Testbed.new()
        self.assertTrue(UMockdev.in_mock_environment())
        self.dbpath = tempfile.mkdtemp()
        self.rundir = tempfile.mkdtemp()
        for subdir in ('devices', 'domains', 'keys'):
            os.makedirs(os.path.join(self.dbpath, subdir))

        # handles filled in on demand by the helper methods
        self.client = self.log = self.daemon = self.polkitd = None
        self.valgrind = False
        self.sdnotify = None

    def tearDown(self):
        """Stop helper processes first, then remove the per-test state.

        The daemon gets the state/runtime directories and the testbed root
        through its environment, so it is stopped before those are removed
        rather than having them pulled out from under a running process. The
        directories are removed even if stopping a helper raises.
        """
        try:
            self.daemon_stop()
            self.polkitd_stop()
            if self.sdnotify is not None:
                self.sdnotify.close()
                self.sdnotify = None
        finally:
            shutil.rmtree(self.dbpath)
            shutil.rmtree(self.rundir)
            del self.testbed
        # Gross work-around for some flaky behavior where tests will fail
        # with some obscure error in umockdev, but only if all the tests
        # executed in one go. Maybe some cleanup issue. Needs further
        # investigation.
        gc.collect()

    # dbus helper methods
    def get_dbus_property(self, name, interface=DBUS_IFACE_MANAGER):
        """Read property *name* of *interface* on the bolt object.

        Goes through org.freedesktop.DBus.Properties.Get with a proxy that
        does not auto-start the service.
        """
        flags = Gio.DBusProxyFlags.DO_NOT_AUTO_START
        props = Gio.DBusProxy.new_sync(self.dbus, flags, None,
                                       DBUS_NAME, DBUS_PATH,
                                       'org.freedesktop.DBus.Properties',
                                       None)
        return props.Get('(ss)', interface, name)

    # daemon helper
    def daemon_start(self, sdnotify=False):
        """Spawn boltd in the mock environment and wait until it answers.

        :param sdnotify: if true, create an SdNotify socket in the runtime
            directory and pass its path to the daemon as NOTIFY_SOCKET.

        The daemon runs with the test's state/runtime dirs and the umockdev
        root; if $VALGRIND is set it runs under valgrind (with that value as
        suppressions file when it exists). Fails the test if the daemon
        exits during start-up or does not answer a 'Version' property read
        within get_timeout('daemon_start') seconds.
        """
        timeout = get_timeout('daemon_start')  # seconds
        env = os.environ.copy()
        env['G_DEBUG'] = 'fatal-criticals'
        env['UMOCKDEV_DIR'] = self.testbed.get_root_dir()
        env['SYSTEMD_DEVICE_VERIFY_SYSFS'] = '0'
        env['STATE_DIRECTORY'] = self.dbpath
        env['RUNTIME_DIRECTORY'] = self.rundir
        if sdnotify:
            if self.sdnotify is not None:
                # release the notifier of a previous run instead of leaking it
                self.sdnotify.close()
            self.sdnotify = SdNotify(self.rundir)
            env['NOTIFY_SOCKET'] = self.sdnotify.path
        argv = [self.paths['daemon'], '-v']
        valgrind = os.getenv('VALGRIND')
        if valgrind is not None:
            argv.insert(0, 'valgrind')
            argv.insert(1, '--leak-check=full')
            if os.path.exists(valgrind):
                argv.insert(2, '--suppressions=%s' % valgrind)
            self.valgrind = True
        self.daemon = subprocess.Popen(argv,
                                       env=env,
                                       stdout=self.log,
                                       stderr=subprocess.STDOUT)

        timeout_count = timeout * 10
        timeout_sleep = 0.1
        while timeout_count > 0:
            time.sleep(timeout_sleep)
            timeout_count -= 1
            # fail fast instead of polling the bus for a daemon that died
            if self.daemon.poll() is not None:
                self.fail('daemon exited during start-up (code %d)'
                          % self.daemon.returncode)
            try:
                self.get_dbus_property('Version')
                break
            except GLib.GError:
                pass
        else:
            self.fail('daemon did not start in %d seconds' % timeout)

        self.client = BoltClient(self.dbus)
        self.assertEqual(self.daemon.poll(), None, 'daemon crashed')

    def daemon_stop(self):
        """Terminate and reap the daemon if one runs, then drop the client."""
        proc = self.daemon
        if proc:
            try:
                proc.terminate()
            except OSError:
                # ignore failures to signal it; wait() below still reaps it
                pass
            proc.wait()

        self.client = None
        self.daemon = None

    def polkitd_start(self):
        """Spawn dbusmock's polkitd template and keep its mock interface.

        The mock's output is discarded; ``self.polkitd`` is the dbusmock
        control interface used to set allowed actions.
        """
        proc, obj = self.spawn_server_template('polkitd', {},
                                               stdout=subprocess.DEVNULL)
        self._polkitd = proc
        self._polkitd_obj = obj
        self.polkitd = dbus.Interface(obj, dbusmock.MOCK_IFACE)

    def polkitd_stop(self):
        """Terminate and reap the mock polkitd, if it was started."""
        if self.polkitd is not None:
            proc = self._polkitd
            proc.terminate()
            proc.wait()
            self.polkitd = None

    def user_config(self, **kwargs):
        """Write *kwargs* as the ``[config]`` section of boltd.conf.

        The file is placed in the state directory and echoed to stdout.
        Option names keep their case; values are converted with str(), so
        non-string values (e.g. ints) are accepted too.
        """
        cfg = configparser.ConfigParser()
        cfg.optionxform = lambda option: option  # keep option names as-is

        # read_dict() stringifies values, whereas item assignment on a
        # section proxy raises TypeError for anything that is not a str
        cfg.read_dict({'config': kwargs})

        path = os.path.join(self.dbpath, 'boltd.conf')
        with open(path, 'w') as f:
            cfg.write(f)

        with open(path, 'r') as f:
            print(f.read())

    # executing boltctl
    def boltctl(self, *args):
        """Run boltctl with *args* and G_MESSAGES_DEBUG=all.

        Returns ``(stdout, stderr, returncode)`` with stdout/stderr as bytes;
        both are echoed when the command exits non-zero.
        """
        env = dict(os.environ)
        env['G_MESSAGES_DEBUG'] = "all"
        cmd = [self.paths['boltctl'], *args]

        print('Calling: ' + " ".join(cmd), file=sys.stderr)
        proc = subprocess.run(cmd, env=env,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              check=False)
        if proc.returncode != 0:
            print('OUTPUT:')
            print(proc.stdout, file=sys.stdout)
            print(proc.stderr, file=sys.stderr)
        return proc.stdout, proc.stderr, proc.returncode

    # mock tree stuff
    @staticmethod
    def default_mock_tree(acl=None):
        """Return a mock domain whose host has the peripherals Cable1,
        Cable2 and SSD1; *acl* is passed through to TbDomain."""
        peripherals = [TbDevice(name) for name in ('Cable1', 'Cable2', 'SSD1')]
        return TbDomain(host=TbHost(peripherals), acl=acl)

    @staticmethod
    def simple_mock_tree():
        """Return a mock domain whose host has a single 'Dock' peripheral."""
        dock = TbDevice('Dock')
        return TbDomain(host=TbHost([dock]))

    def assertGError(self, have, want):
        """Assert that GLib error *have* has the domain and code of *want*.

        *have* may be a GError or an assertRaises context (its ``exception``
        is used). *want* must offer ``quark()`` and convert to the int code.
        Raises ``self.failureException`` on mismatch.
        """
        err = getattr(have, 'exception', have)

        want_domain = GLib.quark_to_string(want.quark())
        want_code = int(want)
        if (err.domain, err.code) != (want_domain, want_code):
            msg = "want: [%s (%d)], have [%s]" % (want, want_code, err)
            raise self.failureException(msg)

    def assertDeviceEqual(self, local, remote):
        """Check that D-Bus device *remote* mirrors mock device *local*.

        Identity, name, vendor, status and authflags are always compared.
        Sysfs path, connection state and parent only when *local* has a
        sysfs path. A local generation of None must appear as 0 remotely.
        Returns True.
        """
        self.assertTrue(local and remote)
        self.assertEqual(local.unique_id, remote.uid)
        self.assertEqual(local.device_name, remote.Name)
        self.assertEqual(local.vendor_name, remote.Vendor)

        connected = local.syspath is not None
        if connected:
            self.assertEqual(local.syspath, remote.sysfs_path)
            self.assertTrue(remote.is_connected)

            # remote.parent is only meaningful while connected
            parent = local.parent
            if isinstance(parent, TbDevice):
                self.assertEqual(parent.unique_id, remote.parent)

        self.assertEqual(local.bolt_status, remote.status)
        self.assertEqual(local.bolt_authflags, remote.authflags)

        expected_gen = 0 if local.generation is None else local.generation
        self.assertEqual(expected_gen, remote.generation)

        return True

    def add_domain_host(self, domain=0, security='secure', uid=None, bootacl=15, iommu=None, nhi=0x15d2):
        """Add a mock NHI PCI device, a thunderbolt domain and its host.

        :param domain: domain index used in the sysfs names.
        :param security: value of the domain's ``security`` attribute.
        :param uid: host unique_id; a random UUID when None.
        :param bootacl: an int N (N >= 1) yields N empty ``boot_acl`` slots,
            a string is used verbatim, None omits the attribute.
        :param iommu: if not None, written (plus newline) to
            ``iommu_dma_protection``.
        :param nhi: PCI device id of the NHI.
        :returns: ``(domain_path, host_path)`` in the testbed.
        """
        uid = str(uuid.uuid4()) if uid is None else uid

        pci_attrs = ['class', '0x088000',
                     'vendor', '0x8086',
                     'device', '0x%04x' % nhi]
        nhi_path = self.testbed.add_device('pci', '0000:00:0d.%d' % domain, None,
                                           pci_attrs,
                                           ['DRIVER', 'thunderbolt'])

        if isinstance(bootacl, int):
            # N slots are written as N-1 separators
            bootacl = ',' * (bootacl - 1)

        props = ['security', security]
        if bootacl is not None:
            props.extend(['boot_acl', bootacl])
        if iommu is not None:
            props.extend(['iommu_dma_protection', str(iommu) + '\n'])

        dc = self.testbed.add_device('thunderbolt', 'domain%d' % domain, nhi_path,
                                     props,
                                     ['DEVTYPE', 'thunderbolt_domain'])

        host_attrs = ['device_name', 'Host',
                      'device', '0x23',
                      'vendor_name', 'GNOME.org',
                      'vendor', '0x23',
                      'authorized', '1',
                      'unique_id', uid]
        host = self.testbed.add_device('thunderbolt', "%d-0" % domain, dc,
                                       host_attrs,
                                       ['DEVTYPE', 'thunderbolt_device'])
        return dc, host

    def remove_domain_host(self, domain, host):
        """Send 'remove' uevents for, and delete, the host, the domain and
        the domain's parent (NHI) device, in that order."""
        nhi = os.path.dirname(domain)
        for path in (host, domain, nhi):
            self.testbed.uevent(path, "remove")
            self.testbed.remove_device(path)

    def add_device(self, parent, devid, name, vendor, domain=0, authorized=1, key='', boot=None):
        """Add a mock thunderbolt peripheral ``<domain>-<devid>`` below *parent*.

        A random unique_id is generated. *key* of None omits the ``key``
        attribute, otherwise a trailing newline is ensured; *boot*, when
        given, is written verbatim. Returns ``(sysfs_path, unique_id)``.
        """
        uid = str(uuid.uuid4())
        props = [
            'device_name', name,
            'device', '0x23',
            'vendor_name', vendor,
            'vendor', '0x23',
            'authorized', '%d' % authorized,
            'unique_id', uid,
        ]

        if key is not None:
            # The kernel always returns the key with trailing `\n`
            props.extend(['key', key if key.endswith('\n') else key + '\n'])

        if boot is not None:
            props.extend(['boot', boot])

        devname = "%d-%d" % (domain, devid)
        path = self.testbed.add_device('thunderbolt', devname, parent,
                                       props,
                                       ['DEVTYPE', 'thunderbolt_device'])
        return path, uid

    def find_device_by_uid(self, lst, uid):
        """Return the single element of *lst* whose ``uid`` equals *uid*;
        fail the test unless exactly one matches."""
        hits = list(filter(lambda dev: dev.uid == uid, lst))
        self.assertEqual(len(hits), 1)
        return hits[0]

    def store_put_device(self, dev, policy='auto', key=None):
        """Write a device entry, and optionally its key, into the test store.

        :param dev: mock device; a TbHost is stored with type ``host``,
            anything else as ``peripheral``.
        :param policy: value stored as the user policy.
        :param key: None stores no key; 'known' stores a fixed test key;
            'device' stores ``dev.key`` (nothing if that is None); any other
            string is stored verbatim.
        """
        df = configparser.ConfigParser()
        df.optionxform = lambda option: option

        uid = dev.unique_id
        df['device'] = {
            'name': dev.device_name,
            'vendor': dev.vendor_name,
            'type': 'host' if isinstance(dev, TbHost) else 'peripheral'
        }

        df['user'] = {
            'storetime': int(time.time()),
            'policy': policy
        }

        if dev.generation:
            # assigning to a section item only accepts str (unlike the dict
            # assignments above, which are stringified); generation is an int
            df['device']['generation'] = str(dev.generation)

        path = os.path.join(self.dbpath, 'devices', uid)
        with open(path, 'w') as f:
            df.write(f)

        if key == 'known':
            key = 'a26d5ad55b011df39ae06cae1fd329babfecac3465fe0a8828d6178f88e59083'
        elif key == 'device':
            key = dev.key

        if key is None:
            return

        path = os.path.join(self.dbpath, 'keys', uid)
        with open(path, 'w') as f:
            f.write(key)

    def store_get_device(self, uid):
        """Load stored device entry *uid* back into a mock device.

        Returns a TbHost for type ``host`` or a TbDevice (name, vendor, uid,
        generation) for ``peripheral``; raises ValueError for other types.
        """
        cfg = configparser.ConfigParser()
        cfg.read(os.path.join(self.dbpath, 'devices', uid))

        section = cfg['device']
        name, vendor = section['name'], section['vendor']
        generation = int(section['generation']) if 'generation' in section else None

        kind = section['type']
        if kind == 'host':
            return TbHost([], gen=generation)
        if kind == 'peripheral':
            return TbDevice(name, vendor=vendor, uid=uid, gen=generation)
        raise ValueError('Unknown device type')

    @contextmanager
    def store_deny_device(self, uid):
        """Temporarily chmod stored device file *uid* to mode 0.

        Yields the file opened O_RDWR; on exit the original permission bits
        are restored and the descriptor is closed.
        """
        fd = os.open(os.path.join(self.dbpath, 'devices', uid), os.O_RDWR)
        saved_mode = os.fstat(fd).st_mode & 0o7777
        try:
            print('store: denying access to %s' % uid, file=sys.stderr)
            os.fchmod(fd, 0)
            yield fd
        finally:
            print('store: restoring access to %s' % uid, file=sys.stderr)
            os.fchmod(fd, saved_mode)
            os.close(fd)

    def store_has_domain(self, uid):
        """Return whether a domain entry named *uid* exists in the store."""
        return os.path.exists(os.path.join(self.dbpath, 'domains', uid))

    def store_create_domain(self, *, uid=None, acl=None):
        """Write a domain entry into the store and return its uid.

        A random UUID is used when *uid* is falsy; *acl* entries are joined
        with commas into the ``bootacl`` key (empty when *acl* is falsy).
        """
        uid = uid or str(uuid.uuid4())
        bootacl = ",".join(acl) if acl else ""

        path = os.path.join(self.dbpath, 'domains', uid)
        with open(path, 'w') as f:
            f.write(f"[domain]\nbootacl={bootacl}\n")
        return uid

    def store_create_journal(self, name, uid, entries):
        """Write journal file ``<dbpath>/<name>/<uid>`` from *entries*.

        Each entry is a dict with ``uid`` and ``op`` keys and becomes one
        line ``"<uid> <op> <timestamp>"``; all lines share a fixed timestamp
        written as 16 upper-case hex digits (the previous '%016iX' printed a
        decimal number followed by a literal 'X').
        NOTE(review): hex is assumed to match the daemon's journal format —
        confirm against boltd.
        """
        ts = 123456
        data = ["%s %s %016X" % (e["uid"], e["op"], ts) for e in entries]
        journal = os.path.join(self.dbpath, name)
        os.makedirs(journal, exist_ok=True)
        path = os.path.join(journal, uid)
        with open(path, 'w') as f:
            f.write("\n".join(data))

    # the actual tests
    def test_basic(self):
        """Basic device life cycle.

        Starts with no devices; connecting the default mock tree must emit
        DeviceAdded for every device and the exported devices must match the
        mock; disconnecting must emit DeviceRemoved for the peripherals, and
        the host must remain stored with the manual policy.
        """
        def make_events(name, devices):
            # one expected signal per device, carrying its bus path
            paths = [GLib.Variant("(o)", (dev.bus_path, )) for dev in devices]
            return [Recorder.Event('signal', name, path, None) for path in paths]

        self.daemon_start()
        version = self.client.version
        assert version is not None
        d = self.client.list_devices()
        self.assertEqual(len(d), 0)
        policy = self.client.default_policy
        self.assertIn(policy, [self.client.POLICY_AUTO,
                               self.client.POLICY_MANUAL])

        # connect all devices and make sure we get the proper events
        tree = self.default_mock_tree()

        with self.client.record() as tape:
            events = make_events('DeviceAdded', tree.devices)
            tree.connect_tree(self.testbed)
            res = tape.wait_for_events(events)
            self.assertTrue(res)

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))
        for remote in devices:
            local = tree.find(False, unique_id=remote.uid)
            self.assertDeviceEqual(local, remote)

        # disconnect all devices again (don't check the host, i.e. peripherals only)
        with self.client.record() as tape:
            events = make_events('DeviceRemoved', tree.peripherals)
            tree.disconnect(self.testbed)
            res = tape.wait_for_events(events)
            self.assertTrue(res)

        devices = self.client.list_peripherals()
        self.assertEqual(len(devices), 0)

        # host device should be stored
        remote = self.client.device_by_uid(tree.host.unique_id)
        self.assertEqual(remote.stored, True)
        self.assertEqual(remote.policy, BoltClient.POLICY_MANUAL)
        self.daemon_stop()

    def test_auth_mode(self):
        """AuthMode toggling gates enrollment.

        With AuthMode 'disabled', enrolling is refused with ACCESS_DENIED
        even though PolKit allows it. After re-enabling, enrolling with the
        default policy stores each device with the client's default policy
        and a StoreTime taken during the test.
        """
        _, host = self.add_domain_host(security='secure')

        self.daemon_start()
        self.polkitd_start()

        # PolKit allows everything, so any refusal below comes from AuthMode
        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage',
                                 'org.freedesktop.bolt.enroll',
                                 'org.freedesktop.bolt.authorize'])

        client = self.client

        self.assertEqual(client.auth_mode, 'enabled')

        # disable the authorization
        with client.record() as tape:
            client.auth_mode = 'disabled'
            tape.wait_for_props(AuthMode='disabled')
        self.assertEqual(client.auth_mode, 'disabled')

        # two unauthorized docks: one with an (empty) key attribute, one
        # without any key attribute
        _, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=0, key='', boot='0')
        _, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=0, key=None, boot='0')

        # the host plus the two docks
        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        d1_remote = self.find_device_by_uid(devices, d1_uid)
        d2_remote = self.find_device_by_uid(devices, d2_uid)

        before = int(time.time())

        remotes = [(d1_remote, d1_uid), (d2_remote, d2_uid)]

        # check we have not automatically authorized devices
        for remote, uid in remotes:
            self.assertEqual(remote.status, BoltDevice.CONNECTED)
            self.assertEqual(remote.stored, False)

        # enrolling must be refused while authorization is disabled
        for remote, uid in remotes:
            with self.assertRaises(GLib.GError) as cm:
                client.enroll(uid)
            self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        # enable the authorization again now
        with client.record() as tape:
            client.auth_mode = 'enabled'
            tape.wait_for_props(AuthMode='enabled')
        self.assertEqual(client.auth_mode, 'enabled')

        # enrolling with POLICY_DEFAULT must resolve to the daemon's default
        policy = BoltClient.POLICY_DEFAULT
        for remote, uid in remotes:
            with remote.record() as tape:
                client.enroll(uid, policy)
                tape.wait_for_props(Stored=True)
                self.assertEqual(remote.policy, client.default_policy)
            now = int(time.time())
            self.assertEqual(remote.stored, True)
            self.assertTrue(remote.StoreTime > 1)
            self.assertTrue(remote.StoreTime >= before)
            self.assertTrue(remote.StoreTime <= now)

    def test_domain_basic(self):
        """Domain basics.

        A security-level change is signalled. Several domains are listed with
        their security level and can be looked up by id. Unplugging clears
        their SysfsPath while they stay reachable by uid, and re-plugging
        restores it.
        """

        def make_uid(domain):
            return '884c6edd-7118-4b21-b186-b02d396ecca%d' % domain

        def make_domain(domain, security):
            name = 'domain%d' % domain
            uid = make_uid(domain)
            dom, host = self.add_domain_host(domain, security=security, uid=uid)
            return name, dom, host

        # check we get a proper event if the security level changes
        self.daemon_start()
        self.assertEqual(self.client.security_level, 'unknown')
        with self.client.record() as tape:
            _, dom, host = make_domain(0, 'user')
            tape.wait_for_props(SecurityLevel='user')
            self.assertEqual(self.client.security_level, 'user')
        self.daemon_stop()
        self.remove_domain_host(dom, host)

        # create more than one domain
        security = 'secure'
        domains = [make_domain(i, security) for i in range(4)]
        self.daemon_start()

        remotes = self.client.list_domains()
        self.assertEqual(len(domains), len(remotes))

        for remote in remotes:
            self.assertEqual(remote.security_level, security)

        # every domain must be reachable via its id
        for name, _, _ in domains:
            remote = self.client.domain_by_id(name)
            self.assertEqual(remote.id, name)

        with self.assertRaises(GLib.GError):
            self.client.domain_by_id("")

        with self.assertRaises(GLib.GError):
            self.client.domain_by_id("nonexistent")

        # unplug all domains: each one loses its sysfs path
        for name, domain, host in domains:
            remote = self.client.domain_by_id(name)
            self.assertIsNotNone(remote)

            with remote.record() as tape:
                self.remove_domain_host(domain, host)
                tape.wait_for_props(SysfsPath='')

        # the unplugged domains are still known by uid; re-plug them
        for d in range(4):
            remote = self.client.domain_by_id(make_uid(d))
            self.assertIsNotNone(remote)

            with remote.record() as tape:
                _, domain, _ = make_domain(d, 'secure')
                tape.wait_for_props(SysfsPath=domain)

        remotes = self.client.list_domains()
        self.assertEqual(len(remotes), 4)

        self.daemon_stop()

    def test_domain_connected(self):
        """A stored but absent domain reports its security level once its
        controller shows up."""
        uid = '884c6edd-7118-4b21-b186-b02d396ecafe'
        # stored entry: "[domain]\nbootacl=\n"
        self.store_create_domain(uid=uid)

        self.daemon_start()
        self.assertEqual(self.client.security_level, 'unknown')

        security = 'secure'
        with self.client.record() as tape:
            self.add_domain_host(0, security=security, uid=uid)
            tape.wait_for_props(SecurityLevel=security)
            self.assertEqual(self.client.security_level, security)

        self.daemon_stop()

    def test_domain_bootacl(self):
        """Domain boot ACL handling.

        All peripherals (but not the host) are listed. Changing the ACL
        needs the 'manage' permission, invalid values are rejected with
        INVALID_ARGUMENT, and valid updates (clearing, setting entries)
        become visible.
        """

        ssd1 = TbDevice('SSD1')
        cable1 = TbDevice('Cable1', children=[ssd1])
        ssd2 = TbDevice('SSD2')
        cable2 = TbDevice('Cable2', children=[ssd2])

        host = TbHost([cable1,
                       cable2])
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        acl=9,
                        host=host)

        tree.connect_tree(self.testbed)

        self.store_put_device(cable1, key=None)
        self.store_put_device(ssd1, key='known')
        self.store_put_device(cable2, key='known')
        self.store_put_device(ssd2, key='known')

        domain_id = host.unique_id

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        domain = client.domain_by_id(domain_id)
        bootacl = domain.bootacl
        print('domain [%s] bootacl: %s' % (domain.uid, bootacl))

        self.assertNotIn(host.unique_id, bootacl)

        devs = tree.peripherals
        for d in devs:
            self.assertIn(d.unique_id, bootacl)

        # try to set the bootacl, without PolKit allowing it
        with self.assertRaises(GLib.GError) as cm:
            domain.bootacl = ['']
        self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)
        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])

        # try to set an invalid bootacl
        with self.assertRaises(GLib.GError) as cm:
            domain.bootacl = ['']
        self.assertGError(cm, Gio.IOErrorEnum.INVALID_ARGUMENT)

        # try to set an invalid bootacl, 2nd time
        with self.assertRaises(GLib.GError) as cm:
            domain.bootacl = None
        self.assertGError(cm, Gio.IOErrorEnum.INVALID_ARGUMENT)

        # remove all entries
        with domain.record() as tape:
            domain.bootacl = [''] * len(bootacl)
            tape.wait_for_props(BootACL=None)
            bootacl = domain.bootacl
            print('domain [%s] bootacl: %s' % (domain.uid, bootacl))
            self.assertTrue(all(x == '' for x in bootacl))

        bootacl_new = [''] * len(bootacl)
        uuids = ["884c6edd-7118-4b21-b186-b02d396ecca0",
                 "58796843-4a8b-4578-921a-acc654984bfb",
                 "0cc8556f-c73b-41c0-990e-f1e317360626"]

        for i, u in enumerate(uuids):
            bootacl_new[i] = u

        with domain.record() as tape:
            domain.bootacl = bootacl_new
            tape.wait_for_props(BootACL=None)
            bootacl = domain.bootacl
            print('domain [%s] bootacl: %s' % (domain.uid, bootacl))
            # every entry we wrote must now be in the daemon's bootacl
            # (checking against `uuids` itself would always pass)
            for entry in uuids:
                self.assertIn(entry, bootacl)

        self.daemon_stop()

    def test_domain_integrated_tbt(self):
        """A domain on an Ice Lake NHI (PCI id 0x8a17) must not be written
        to the store, and unplugging it must emit DomainRemoved."""
        # On integrated TBT (like ICL/TGL) the UUID of the controller is not
        # stable, i.e. it changes between reboots. We therefore must not be
        # saving the domain
        def make_events(name, devices):
            # expected signals carry each domain's D-Bus object path
            paths = [GLib.Variant("(o)", (dev.object_path, )) for dev in devices]
            return [Recorder.Event('signal', name, path, None) for path in paths]

        # no boot_acl attribute, IOMMU DMA protection on
        ice_lake = {
            "security": 'none',
            "bootacl":None,
            "iommu": "1",
            "nhi": 0x8a17,  # ice lake
        }

        dom, host = self.add_domain_host(**ice_lake)

        self.daemon_start()
        client = self.client

        domains = client.list_domains()
        for domain in domains:
            self.assertFalse(self.store_has_domain(domain.uid))

        with self.client.record() as tape:
            events = make_events('DomainRemoved', domains)

            self.remove_domain_host(dom, host)

            res = tape.wait_for_events(events)
            self.assertTrue(res)

        self.daemon_stop()

    def test_domain_cleanup_stale(self):
        """Stored domains that are not connected and have no boot ACL
        journal must not be listed after start-up; one with a journal must."""
        stale = [self.store_create_domain() for _ in range(4)]

        uid = self.store_create_domain()
        # uuid4 must be called: str(uuid.uuid4) is the function's repr,
        # which contains spaces and would garble the journal line
        acls = [{"uid": str(uuid.uuid4()), "op": "+"}]
        self.store_create_journal("bootacl", uid, acls)

        self.daemon_start()
        client = self.client

        have = [domain.uid for domain in client.list_domains()]
        for domain_uid in have:
            self.assertNotIn(domain_uid, stale)

        self.assertIn(uid, have)

        self.daemon_stop()

    def test_signals_on_start(self):
        """Devices already present at daemon start-up (unauthorized, not in
        the database) must be announced with DeviceAdded."""
        client = BoltClient(self.dbus)
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        with client.record() as tape:
            self.daemon_start()
            got = tape.wait_for_event('signal', 'DeviceAdded', None)
            self.assertTrue(got)
        self.daemon_stop()

    def test_basic_device_name(self):
        """Names and vendors are exported verbatim, and labels are derived
        from them (duplicates numbered, vendor names cleaned up)."""
        # prepare the basic setup
        _, host = self.add_domain_host()

        # Rows are (name, vendor, expected label). Keep the commas explicit:
        # adjacent string literals would be silently concatenated.
        # pylint: disable=bad-whitespace
        devs = [
            # name              vendor          label                      notes
            ('GNOME.org Cable', 'GNOME.org',    'GNOME.org Cable'),        # duplicated vendor name
            ('GNOME.org Cable', 'GNOME.org',    'GNOME.org Cable #2'),     # duplicated device
            ('GNOME.org Cable', 'GNOME.org',    'GNOME.org Cable #3'),     # duplicated device, again
            ('⍾ Laptop',        'Evil Corp. ☢', 'Evil Corp. ☢ ⍾ Laptop'),  # utf-8 chars
            ('HP TB3 Dock',     'HP Inc.',      'HP TB3 Dock'),            # cleanup company names
            ('TB Gadget',       'Apple, Inc.',  'Apple TB Gadget'),        # cleanup company names
        ]
        # pylint: enable=bad-whitespace

        devs = [{'name': name, 'vendor': vendor, 'label': label, 'id': i}
                for i, (name, vendor, label) in enumerate(devs, start=1)]

        for d in devs:
            did, name, vendor = d['id'], d['name'], d['vendor']
            path, uid = self.add_device(host, did, name, vendor)
            d['path'] = path
            d['uid'] = uid

        self.daemon_start()
        devices = self.client.list_devices()
        # all peripherals plus the host
        self.assertEqual(len(devices), len(devs) + 1)

        for d in devs:
            remote = self.find_device_by_uid(devices, d['uid'])
            self.assertEqual(remote.name, d['name'])
            self.assertEqual(remote.vendor, d['vendor'])
            self.assertEqual(remote.label, d['label'])

        self.daemon_stop()

    def test_basic_user_config(self):
        """DefaultPolicy from boltd.conf becomes the daemon's default policy."""
        self.user_config(DefaultPolicy='manual')

        self.daemon_start()
        self.assertEqual(self.client.default_policy, self.client.POLICY_MANUAL)
        self.daemon_stop()

    def test_device_by_uid(self):
        """device_by_uid rejects bogus uids and finds every mock device."""
        self.daemon_start()

        for bogus in ("", "nonexistent"):
            with self.assertRaises(GLib.GError):
                self.client.device_by_uid(bogus)

        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        for dev in tree.devices:
            remote = self.client.device_by_uid(dev.unique_id)
            self.assertIsNotNone(remote)
            self.assertDeviceEqual(dev, remote)

        self.daemon_stop()

    def test_device_authflags(self):
        """Device authflags follow the domain security level.

        dponly and usbonly yield NOPCIE. In secure mode, a device with
        authorized=2 and a key yields SECURE, and one without key and with
        boot=1 yields NOKEY | BOOT.
        """
        key = self.key

        dc, host = self.add_domain_host(security='dponly')

        d1, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=1, key=None)
        d2, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None)

        self.daemon_start()
        # the host plus the two docks
        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        device = self.find_device_by_uid(devices, d1_uid)
        self.assertEqual(device.authflags, BoltDevice.NOPCIE)
        device = self.find_device_by_uid(devices, d2_uid)
        self.assertEqual(device.authflags, BoltDevice.NOPCIE)
        self.daemon_stop()

        # same devices, usbonly security level (daemon restarted to pick it up)
        self.testbed.set_attribute(dc, 'security', 'usbonly')
        self.daemon_start()
        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        device = self.find_device_by_uid(devices, d1_uid)
        self.assertEqual(device.authflags, BoltDevice.NOPCIE)
        device = self.find_device_by_uid(devices, d2_uid)
        self.assertEqual(device.authflags, BoltDevice.NOPCIE)
        self.daemon_stop()

        # secure mode
        self.testbed.remove_device(d1)
        self.testbed.remove_device(d2)
        self.testbed.set_attribute(dc, 'security', 'secure')

        # NOTE(review): authorized=2 presumably means "authorized with secure
        # key" in the kernel sysfs ABI — confirm
        d1, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=2, key=key, boot='0')
        d2, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None, boot='1')

        self.daemon_start()
        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        device = self.find_device_by_uid(devices, d1_uid)
        flags = BoltDevice.SECURE
        self.assertEqual(device.authflags, flags)

        device = self.find_device_by_uid(devices, d2_uid)
        flags = BoltDevice.NOKEY | BoltDevice.BOOT
        self.assertEqual(device.authflags, flags)

        self.daemon_stop()

    def test_device_authorize(self):
        """Device authorization.

        Authorizing requires PolKit permission and updates Status and
        AuthorizeTime. Authorizing an already authorized device fails, and
        authorizing an enrolled but disconnected device fails with BadState.
        """
        self.daemon_start()
        client = self.client
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        self.polkitd_start()

        to_authorize = tree.collect(TbDevice.is_unauthorized)

        # check that we are not allowed to authorize devices
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError) as cm:
                remote.authorize()
            self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize'])
        before = int(time.time())
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            # context manager so the recorder is closed even if an
            # assertion below fails (previously only closed on success)
            with remote.record() as tape:
                remote.authorize()
                d.reload_auth()  # will emit the uevent, so the daemon can update
                tape.wait_for_props(Status='authorized')
                self.assertDeviceEqual(d, remote)
                # make sure AuthorizeTime is correct
                now = int(time.time())
                self.assertTrue(remote.AuthorizeTime > 1)
                self.assertTrue(remote.AuthorizeTime >= before)
                self.assertTrue(remote.AuthorizeTime <= now)

        # authorizing an already authorized device must fail
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError):
                remote.authorize()

        # enroll the device, disconnect it and then try to authorize it,
        # verifying we are rejecting the call properly
        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize',
                                 'org.freedesktop.bolt.enroll'])
        d = to_authorize[0]
        d.writekey('a26d5ad55b011df39ae06cae1fd329babfecac3465fe0a8828d6178f88e59083')
        remote = self.client.device_by_uid(d.unique_id)
        with remote.record() as tape:
            client.enroll(d.unique_id, 'manual')
            tape.wait_for_props(Stored=True)
            d.disconnect(self.testbed)
            tape.wait_for_props(Status='disconnected')
            with self.assertRaises(GLib.GError) as cm:
                remote.authorize()
            self.assertIn('bolt.Error.BadState', cm.exception.message)
        self.daemon_stop()

    def test_device_auto_auth_sl2(self):
        """Check which stored devices boltd auto-authorizes in SECURE mode.

        Devices stored without a key, with the 'manual' policy, or with
        the 'iommu' policy (no iommu argument is passed to the domain
        here, presumably meaning IOMMU is off; TODO confirm) must stay
        connected, as must their children. Manually authorizing a cable
        must make boltd pick up the udev change and auto-authorize the
        stored device behind it.
        """
        ssd1 = TbDevice('SSD1',)
        cable1 = TbDevice('Cable1', children=[ssd1])
        ssd2 = TbDevice('SSD2')
        cable2 = TbDevice('Cable2', children=[ssd2])
        dock1 = TbDevice('Dock1')
        cable3 = TbDevice('Cable3', children=[dock1])
        dock2 = TbDevice('Dock2')
        cable4 = TbDevice('Cable4', children=[dock2])
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        host=TbHost([
                            cable1,
                            cable2,
                            cable3,
                            cable4,
                        ]))
        tree.connect_tree(self.testbed)

        self.store_put_device(cable1, key=None)
        self.store_put_device(ssd1, key='known')
        self.store_put_device(cable2, key='known')
        self.store_put_device(ssd2, key='known')
        self.store_put_device(cable3, key='known', policy='manual')
        self.store_put_device(dock1, key='known')
        self.store_put_device(cable4, key='known', policy='iommu')
        self.store_put_device(dock2, key=None, policy='iommu')

        self.daemon_start()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))

        remote_ssd2 = self.find_device_by_uid(devices, ssd2.unique_id)
        self.assertIsNotNone(remote_ssd2)

        with remote_ssd2.record() as tape:
            if remote_ssd2.status != BoltDevice.AUTHORIZED:
                tape.wait_for_props(Status='authorized')

        self.assertEqual(remote_ssd2.status, BoltDevice.AUTHORIZED)

        # cable1 does *NOT* have a key but we are in SECURE mode
        # so it should not be authorized
        remote_c1 = self.find_device_by_uid(devices, cable1.unique_id)
        self.assertEqual(remote_c1.status, BoltDevice.CONNECTED)

        remote_ssd1 = self.find_device_by_uid(devices, ssd1.unique_id)
        self.assertEqual(remote_ssd1.status, BoltDevice.CONNECTED)

        # cable3 has MANUAL policy and therefore should *NOT* be authorized
        remote_c3 = self.find_device_by_uid(devices, cable3.unique_id)
        self.assertEqual(remote_c3.status, BoltDevice.CONNECTED)

        remote_dock1 = self.find_device_by_uid(devices, dock1.unique_id)
        self.assertEqual(remote_dock1.status, BoltDevice.CONNECTED)

        # cable4 has IOMMU policy and therefore should *NOT* be authorized
        remote_c4 = self.find_device_by_uid(devices, cable4.unique_id)
        self.assertEqual(remote_c4.status, BoltDevice.CONNECTED)

        remote_dock2 = self.find_device_by_uid(devices, dock2.unique_id)
        self.assertEqual(remote_dock2.status, BoltDevice.CONNECTED)

        # now we pretend the user has authorized the device manually,
        # to check if boltd picks up udev changes properly and then
        # auto-authorizes also SSD1

        # we start the tape recorder for ssd1 too, so we don't miss its events
        tape_ssd1 = remote_ssd1.record()
        with remote_c1.record() as tape:
            cable1.authorize('1')
            tape.wait_for_props(Status='authorized')
        self.assertEqual(remote_c1.status, BoltDevice.AUTHORIZED)

        tape_ssd1.wait_for_props(Status='authorized')
        events = tape_ssd1.close()
        self.assertTrue(Recorder.events_list_contains(events, 'property', 'Status', 'authorizing'))
        self.assertEqual(remote_ssd1.status, BoltDevice.AUTHORIZED)
        self.daemon_stop()

    def test_device_auto_auth_iommu(self):
        """Auto-authorization in SECURE mode with IOMMU enabled.

        A stored device with a key is auto-authorized, while a device
        stored without a key (and its child) stays connected even though
        the domain has IOMMU enabled.
        """
        ssd1 = TbDevice('SSD1',)
        cable1 = TbDevice('Cable1', children=[ssd1])
        ssd2 = TbDevice('SSD2')
        cable2 = TbDevice('Cable2', children=[ssd2])

        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        iommu='1',
                        host=TbHost([
                            cable1,
                            cable2,
                        ]))
        tree.connect_tree(self.testbed)

        self.store_put_device(cable1, key=None, policy='iommu')
        self.store_put_device(ssd1, key='known', policy='iommu')
        self.store_put_device(cable2, key='known', policy='iommu')
        self.store_put_device(ssd2, key='known', policy='auto')

        self.daemon_start()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))

        remote_ssd2 = self.find_device_by_uid(devices, ssd2.unique_id)
        self.assertIsNotNone(remote_ssd2)

        with remote_ssd2.record() as tape:
            if remote_ssd2.status != BoltDevice.AUTHORIZED:
                tape.wait_for_props(Status='authorized')

        self.assertEqual(remote_ssd2.status, BoltDevice.AUTHORIZED)

        # cable1 does *NOT* have a key but we are in SECURE mode
        # so it should not be authorized even though IOMMU is on
        remote_c1 = self.find_device_by_uid(devices, cable1.unique_id)
        self.assertEqual(remote_c1.status, BoltDevice.CONNECTED)

        remote_ssd1 = self.find_device_by_uid(devices, ssd1.unique_id)
        self.assertEqual(remote_ssd1.status, BoltDevice.CONNECTED)

        self.daemon_stop()

    def device_import_test(self, security, devs, iommu):
        """Check which already connected devices boltd imports at startup.

        Every entry of *devs* is a dict with the keys 'authorized', 'key'
        and 'boot' describing the device to create, plus 'policy', the
        policy the device is expected to be stored with (None meaning it
        must not be stored). The 'path' and 'uid' keys are filled in for
        each entry.
        """
        _, host = self.add_domain_host(security=security, iommu=iommu)

        for did, spec in enumerate(devs, start=1):
            spec['path'], spec['uid'] = self.add_device(host,
                                                        did,
                                                        "Dock%d" % did,
                                                        "GNOME.org",
                                                        authorized=spec['authorized'],
                                                        key=spec['key'],
                                                        boot='%d' % spec['boot'])

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        devices = client.list_devices()
        # the created devices plus the host
        self.assertEqual(len(devices), len(devs) + 1)

        for spec in devs:
            remote = self.find_device_by_uid(devices, spec['uid'])

            if spec['authorized'] > 0:
                expected_status = BoltDevice.AUTHORIZED
            else:
                expected_status = BoltDevice.CONNECTED
            self.assertEqual(remote.status, expected_status)

            expected_policy = spec['policy']
            if expected_policy is not None:
                self.assertEqual(remote.policy, expected_policy)
            self.assertEqual(remote.stored, expected_policy is not None)

        self.daemon_stop()

    @Parameterize.make(iommu='1')
    def test_device_import_sl0(self, iommu=None):
        # security level 'none': kinda game over, the devices already have
        # full access, so both of them are expected to be imported with
        # the 'iommu' policy, with or without the boot flag
        devs = [
            {'authorized': 1, 'key': None, 'boot': boot, 'policy': 'iommu'}
            for boot in (0, 1)
        ]

        self.device_import_test('none', devs, iommu)

    @Parameterize.make(iommu='1')
    def test_device_import_sl1(self, iommu=None):
        # security level 'user': only authorized devices carrying the boot
        # flag are expected to be imported; their policy depends on iommu
        if iommu:
            expected = 'iommu'
        else:
            expected = 'auto'

        devs = [
            # no boot flag -> not importing
            {'authorized': 1, 'key': None, 'boot': 0, 'policy': None},
            # boot flag, and user mode -> importing
            {'authorized': 1, 'key': None, 'boot': 1, 'policy': expected},
        ]

        if not iommu:
            # check we are not auto-importing un-authorized devices
            devs.extend(
                {'authorized': 0, 'key': k, 'boot': 0, 'policy': None}
                for k in (None, '')
            )

        self.device_import_test('user', devs, iommu)

    @Parameterize.make(iommu='1')
    def test_device_import_sl2(self, iommu=None):
        # like test_device_import_sl1 but in SECURE mode
        key = self.key

        # pylint: disable=bad-whitespace
        devs = [
            # no boot flag, not importing
            {'authorized': 1, 'key': None, 'boot': 0, 'policy': None},
            {'authorized': 2, 'key': key,  'boot': 0, 'policy': None},
            # boot flag present but no valid key, not importing
            {'authorized': 1, 'key': None, 'boot': 1, 'policy': 'iommu'},
            {'authorized': 1, 'key': '',   'boot': 1, 'policy': 'iommu'},
            # boot, valid key, should not be possible to observe "in the wild"
            {'authorized': 2, 'key': key,  'boot': 1, 'policy': 'iommu'},
        ]

        if not iommu:
            devs += [
                # check we are not auto-importing un-authorized devices
                {'authorized': 0, 'key': None, 'boot': 0, 'policy': None},
                {'authorized': 0, 'key': '',   'boot': 0, 'policy': None},
            ]
        # pylint: enable=bad-whitespace

        self.device_import_test('secure', devs, iommu)

    @Parameterize.make(iommu='1')
    def test_device_import_sl3(self, iommu=None):
        # dponly mode, i.e. no PCI express tunnels: all devices always show
        # up as authorized, yet none of them is expected to be imported,
        # with or without the boot flag
        devs = [
            {'authorized': 1, 'key': None, 'boot': boot, 'policy': None}
            for boot in (0, 1)
        ]

        self.device_import_test('dponly', devs, iommu)

    def test_device_auto_enroll(self):
        """With IOMMU enabled, unauthorized devices end up authorized and stored.

        The negative tests, i.e. that devices are not auto-enrolled when
        iommu is disabled, are done via test_basic and
        test_device_import_{sl1, sl2}.
        """
        tree = TbDomain(host=TbHost([
            TbDevice('Cable1', children=[
                TbDevice('Dock1')
            ]),
            TbDevice('Cable2', children=[
                TbDevice('SSD1', key=None)
            ]),
        ]), acl=8, iommu='1')

        tree.connect_tree(self.testbed)

        devs = tree.collect(TbDevice.is_unauthorized)

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        devices = client.list_devices()
        # the collected devices plus the host
        self.assertEqual(len(devices), len(devs) + 1)

        # devices closer to the root first
        for d in sorted(devs, key=lambda dev: len(dev.syspath)):
            remote = self.find_device_by_uid(devices, d.unique_id)
            self.assertIsNotNone(remote)

            with remote.record() as tape:
                if remote.status != BoltDevice.AUTHORIZED:
                    tape.wait_for_props(Stored=True)

            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, True)

        self.daemon_stop()

    def test_device_enroll(self):
        """Enroll unauthorized devices and check they survive a reconnect.

        Verifies the polkit check, the error for an unknown device, the
        properties of freshly enrolled devices (key, policy, StoreTime,
        AuthorizeTime, bootacl membership for the 'auto' policy), and that
        after disconnecting and reconnecting the tree the devices are still
        known and get authorized (auto) or stay connected (manual).
        """
        self.daemon_start()
        tree = self.default_mock_tree(acl=16)
        tree.connect_tree(self.testbed)
        self.polkitd_start()

        client = self.client
        domain = client.domain_by_id(tree.unique_id)
        self.assertIsNotNone(domain)
        self.assertIsNotNone(domain.bootacl)

        to_enroll = tree.collect(TbDevice.is_unauthorized)

        # all checks below are vacuous without any device to enroll
        self.assertNotEqual(len(to_enroll), 0)

        # check that we are not allowed to enroll devices, i.e. the correct
        # policykit action is called.

        for d in to_enroll:
            with self.assertRaises(GLib.GError) as cm:
                client.enroll(d.unique_id)
            self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        # check we get a proper error for an unknown device
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.enroll("884c6edd-7118-4b21-b186-b02d396ecca0")

        before = int(time.time())
        policies = {}
        for i, d in enumerate(to_enroll):
            if i % 2 == 0:
                policy = BoltClient.POLICY_AUTO
                # enrolling with the auto policy changes the domain's bootacl
                with domain.record() as tape:
                    remote = client.enroll(d.unique_id, policy)
                    tape.wait_for_props(BootACL=None)
            else:
                policy = BoltClient.POLICY_MANUAL
                remote = client.enroll(d.unique_id, policy)

            d.reload_auth()  # will emit the uevent, so the daemon can update
            # the security level for the domain is SECURE, which means we should
            # have authorized via a new key:
            #  status should be AUTHORIZED
            #  stored should be True
            #  key(state) should be KEY_NEW
            #  authflags should be 'secure'
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)
            # check StoreTime and AuthorizeTime are within the test's window
            now = int(time.time())
            self.assertGreater(remote.StoreTime, 1)
            self.assertGreaterEqual(remote.StoreTime, before)
            self.assertLessEqual(remote.StoreTime, now)
            self.assertGreater(remote.AuthorizeTime, 1)
            self.assertGreaterEqual(remote.AuthorizeTime, before)
            self.assertLessEqual(remote.AuthorizeTime, now)

            if policy == BoltClient.POLICY_AUTO:
                self.assertIn(d.unique_id, domain.bootacl)
            else:
                self.assertNotIn(d.unique_id, domain.bootacl)
            policies[d.unique_id] = policy

        # we disconnect the tree, but since the devices are enrolled
        # the daemon should still have them in its database
        tree.disconnect(self.testbed)

        expected_number = len(tree.peripherals)
        devices = self.client.list_peripherals()
        tries = 0
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_peripherals()
        self.assertEqual(len(devices), expected_number)

        tree.connect(self.testbed)              # we connect the domain again
        tree.children[0].connect(self.testbed)  # and the host too

        for remote in devices:
            policy = policies[remote.uid]
            local = tree.find(False, unique_id=remote.uid)
            self.assertDeviceEqual(local, remote)
            self.assertTrue(remote.stored)
            # key status should have changed to HAVE from NEW
            self.assertEqual(remote.key, BoltDevice.KEY_HAVE)
            self.assertEqual(remote.policy, policy)

            # now we connect that specific device and wait for
            # the property changes
            with remote.record() as tape:
                local.connect(self.testbed)
                if policy == BoltClient.POLICY_AUTO:
                    status = 'authorized'
                else:
                    status = 'connected'
                tape.wait_for_props(Status=status)
                local.reload_auth()  # will emit the uevent, so the daemon can update
                self.assertDeviceEqual(local, remote)

    def test_enroll_authorized(self):
        """Enroll devices that are already authorized when boltd starts.

        Both devices (one with a key, one without) start out authorized but
        not stored; enrolling them with the default policy must store them
        with the client's default policy and a plausible StoreTime.
        """
        key = self.key

        _, host = self.add_domain_host()
        _, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=2, key=key, boot='0')
        _, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None, boot='0')

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        pairs = [(self.find_device_by_uid(devices, uid), uid)
                 for uid in (d1_uid, d2_uid)]

        before = int(time.time())

        for remote, _ in pairs:
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, False)

        policy = BoltClient.POLICY_DEFAULT
        for remote, uid in pairs:
            with remote.record() as tape:
                client.enroll(uid, policy)
                tape.wait_for_props(Stored=True)
            self.assertEqual(remote.policy, client.default_policy)

            now = int(time.time())
            self.assertEqual(remote.stored, True)
            self.assertGreater(remote.StoreTime, 1)
            self.assertGreaterEqual(remote.StoreTime, before)
            self.assertLessEqual(remote.StoreTime, now)

        self.daemon_stop()

    def test_enroll_iommu(self):
        """Enroll devices with the default policy on an IOMMU-enabled domain.

        Authorization is globally disabled while the devices appear, so they
        must be neither authorized nor stored automatically; after turning
        it back on, enrolling with the default policy must store them with
        the 'iommu' policy.
        """
        _, host = self.add_domain_host(security='secure', iommu='1')

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage',
                                 'org.freedesktop.bolt.enroll',
                                 'org.freedesktop.bolt.authorize'])

        def switch_auth_mode(mode):
            # change the global AuthMode and wait until it is reported back
            with client.record() as tape:
                client.auth_mode = mode
                tape.wait_for_props(AuthMode=mode)
            self.assertEqual(client.auth_mode, mode)

        # disable the authorization globally
        switch_auth_mode('disabled')

        _, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=0, key='', boot='0')
        _, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=0, key=None, boot='0')

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        remotes = [(self.find_device_by_uid(devices, uid), uid)
                   for uid in (d1_uid, d2_uid)]

        before = int(time.time())

        # check we have not automatically enrolled/authorized devices
        for remote, _ in remotes:
            self.assertEqual(remote.status, BoltDevice.CONNECTED)
            self.assertEqual(remote.stored, False)

        # enable the authorization again now
        switch_auth_mode('enabled')

        # with iommu enabled, the policy should be adjusted to 'IOMMU'
        policy = BoltClient.POLICY_DEFAULT
        for remote, uid in remotes:
            with remote.record() as tape:
                client.enroll(uid, policy)
                tape.wait_for_props(Stored=True)
                self.assertEqual(remote.policy, client.POLICY_IOMMU)

            now = int(time.time())
            self.assertEqual(remote.stored, True)
            self.assertGreater(remote.StoreTime, 1)
            self.assertGreaterEqual(remote.StoreTime, before)
            self.assertLessEqual(remote.StoreTime, now)

    def test_device_key_upgrade(self):
        """Key state of devices authorized in SECURE mode.

        All three devices start out with KEY_MISSING. After authorization
        the stored dock reports a new key, while SSD1 (not stored) and SSD2
        (stored, but with its key set to None) stay at KEY_MISSING.
        """
        dock = TbDevice('Dock',)
        ssd1 = TbDevice('SSD1',)
        ssd2 = TbDevice('SSD2',)
        ssd2.key = None
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        host=TbHost([dock, ssd1, ssd2]))

        tree.connect_tree(self.testbed)
        for known in (dock, ssd2):
            self.store_put_device(known)

        self.daemon_start()
        self.polkitd_start()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))
        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize'])

        remote_dock, remote_ssd1, remote_ssd2 = [
            self.find_device_by_uid(devices, dev.unique_id)
            for dev in (dock, ssd1, ssd2)
        ]

        for remote in (remote_dock, remote_ssd1, remote_ssd2):
            self.assertEqual(remote.key, BoltDevice.KEY_MISSING)

        with remote_dock.record() as tape:
            remote_dock.authorize()
            tape.wait_for_props(Key='new')
        self.assertEqual(remote_dock.key, BoltDevice.KEY_NEW)

        for remote in (remote_ssd1, remote_ssd2):
            with remote.record() as tape:
                remote.authorize()
                tape.wait_for_props(Status='authorized')
            self.assertEqual(remote.key, BoltDevice.KEY_MISSING)

    def test_device_forget(self):
        """Forget enrolled devices after they have been disconnected.

        Forgetting requires the 'manage' polkit action, forgetting an
        unknown device raises an error, and each forgotten device (all
        enrolled with the 'auto' policy) is removed from the domain's
        bootacl.
        """
        self.daemon_start()
        tree = self.default_mock_tree(acl=16)
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        to_enroll = tree.collect(TbDevice.is_unauthorized)
        policy = BoltClient.POLICY_AUTO
        for d in to_enroll:
            remote = client.enroll(d.unique_id, policy)
            d.reload_auth()
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)

        tree.disconnect(self.testbed)
        expected_number = len(tree.peripherals)
        devices = self.client.list_peripherals()
        tries = 0
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_peripherals()
        self.assertEqual(len(devices), expected_number)

        for remote in devices:
            with self.assertRaises(GLib.GError) as cm:
                client.forget(remote.uid)
            self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])

        # check we get a proper error for an unknown device
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.forget("884c6edd-7118-4b21-b186-b02d396ecca0")

        domain = client.domain_by_id(tree.unique_id)

        # now we actually forget the device
        for remote in devices:
            self.assertIn(remote.uid, domain.bootacl)
            with domain.record() as tape:
                client.forget(remote.uid)
                tape.wait_for_props(BootACL=None)
            self.assertNotIn(remote.uid, domain.bootacl)

        devices = self.client.list_peripherals()
        self.assertEqual(len(devices), 0)

    def test_device_label(self):
        """Labels of enrolled devices.

        An enrolled device gets a default label made of vendor and device
        name. Changing it requires the 'manage' polkit action, and empty or
        whitespace-only labels are rejected without changing the label.
        """
        self.daemon_start()
        tree = self.simple_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        local = tree.collect(TbDevice.is_unauthorized)[0]
        policy = BoltClient.POLICY_AUTO

        remote = client.enroll(local.unique_id, policy)
        local.reload_auth()

        default_label = "%s %s" % (local.vendor_name, local.device_name)
        self.assertEqual(remote.label, default_label)

        self.assertDeviceEqual(local, remote)
        self.assertTrue(remote.stored)
        self.assertEqual(remote.key, BoltDevice.KEY_NEW)
        self.assertEqual(remote.policy, policy)

        with self.assertRaises(GLib.GError) as cm:
            remote.label = 'not authorized'
        self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])
        for val in ['', ' ', '     ']:
            with self.assertRaises(GLib.GError) as cm:
                remote.label = val
            self.assertGError(cm, Gio.DBusError.INVALID_ARGS)

        # the rejected updates must not have changed the label
        self.assertEqual(remote.label, default_label)

        # store update failure check is done in test_device_store_failures

        val = 'A valid label'

        with remote.record() as tape:
            remote.label = val
            tape.wait_for_props(Label=val)
        self.assertEqual(remote.label, val)

        self.daemon_stop()

    @unittest.skipIf(can_override_dac(), "have DAC override")
    def test_device_store_failures(self):
        """Property updates that the store fails to persist are reported.

        While the store entry of the device is denied, setting the label or
        the policy must fail with PERMISSION_DENIED and leave the property
        at its previous value.
        """
        self.daemon_start()
        tree = self.simple_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll',
                                 'org.freedesktop.bolt.manage'])

        local = tree.collect(TbDevice.is_unauthorized)[0]
        enrolled_policy = BoltClient.POLICY_AUTO

        remote = client.enroll(local.unique_id, enrolled_policy)
        local.reload_auth()

        original_label = remote.label

        # a label update that cannot be written to the store ...
        with self.assertRaises(GLib.GError) as cm:
            with self.store_deny_device(local.unique_id):
                remote.label = 'denied'
        self.assertGError(cm, Gio.IOErrorEnum.PERMISSION_DENIED)
        self.assertEqual(remote.label, original_label)

        # ... and the same for a policy update
        with self.assertRaises(GLib.GError) as cm:
            with self.store_deny_device(local.unique_id):
                remote.policy = 'iommu'
        self.assertGError(cm, Gio.IOErrorEnum.PERMISSION_DENIED)
        self.assertEqual(remote.policy, enrolled_policy)

        self.daemon_stop()

    def test_device_policy(self):
        """Changing the policy of stored devices.

        Setting the policy requires the 'manage' polkit action, fails for
        devices that are not stored and for invalid policy names, and
        switching a device to 'auto' / 'manual' adds it to / removes it from
        the bootacl of every domain.
        """
        dock = TbDevice('Dock', gen=3)
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        acl=16,
                        host=TbHost([dock]))
        tree.connect_tree(self.testbed)

        # multiple domain controller for bootacl checks
        ssd = TbDevice('SSD', gen=4)
        dom2 = TbDomain(security=TbDomain.SECURITY_SECURE,
                        index=1, acl=16,
                        host=TbHost([ssd], name="Laptop1"))
        dom2.connect_tree(self.testbed)

        self.daemon_start()
        self.polkitd_start()

        client = self.client
        remote = client.device_by_uid(dock.unique_id)

        # we are not allowed to manage the device
        with self.assertRaises(GLib.GError) as cm:
            remote.policy = 'iommu'
        self.assertGError(cm, Gio.DBusError.ACCESS_DENIED)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage',
                                 'org.freedesktop.bolt.enroll'])

        # device is not stored
        with self.assertRaises(GLib.GError) as cm:
            remote.policy = 'iommu'
        self.assertGError(cm, Gio.DBusError.INVALID_ARGS)

        # enroll the device with manual policy
        policy = BoltClient.POLICY_MANUAL
        with remote.record() as tape:
            remote = client.enroll(dock.unique_id, policy)
            dock.reload_auth()
            tape.wait_for_props(Stored=True)
        self.assertEqual(remote.policy, BoltClient.POLICY_MANUAL)

        # policy is invalid
        with self.assertRaises(GLib.GError) as cm:
            remote.policy = 'foobar'
        self.assertGError(cm, Gio.DBusError.INVALID_ARGS)

        # store update failure check is done in test_device_store_failures

        # finally a valid update
        with remote.record() as tape:
            remote.policy = 'iommu'
            tape.wait_for_props(Policy='iommu')
        self.assertEqual(remote.policy, BoltClient.POLICY_IOMMU)

        # enroll the ssd with manual policy
        remote = client.device_by_uid(ssd.unique_id)
        policy = BoltClient.POLICY_MANUAL
        with remote.record() as tape:
            remote = client.enroll(ssd.unique_id, policy)
            ssd.reload_auth()
            tape.wait_for_props(Stored=True)
        self.assertEqual(remote.policy, BoltClient.POLICY_MANUAL)

        # policy -> auto, check that it got added to the bootacl
        with remote.record() as tape:
            remote.policy = 'auto'
            tape.wait_for_props(Policy='auto')
        self.assertEqual(remote.policy, BoltClient.POLICY_AUTO)

        for dom in [tree, dom2]:
            domain = self.client.domain_by_id(dom.unique_id)
            self.assertIn(ssd.unique_id, domain.bootacl)

        # policy -> manual, check it got removed from the bootacl
        with remote.record() as tape:
            remote.policy = 'manual'
            tape.wait_for_props(Policy='manual')
        self.assertEqual(remote.policy, BoltClient.POLICY_MANUAL)

        for dom in [tree, dom2]:
            domain = self.client.domain_by_id(dom.unique_id)
            self.assertNotIn(ssd.unique_id, domain.bootacl)

    def test_device_generation(self):
        """Device generations are exported, and the client reports the host's.

        The tree mixes devices of several generations with one device whose
        generation is left unset (gen=None).
        """
        dock = TbDevice('Dock', gen=3)
        ssd1 = TbDevice('SSD1', gen=1)
        ssd2 = TbDevice('SSD2', gen=2)
        # a device without an explicit generation
        eth = TbDevice('Ethernet', gen=None)
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        host=TbHost([
                            dock,
                            ssd1,
                            ssd2,
                            eth,
                        ], gen=4))

        tree.connect_tree(self.testbed)
        self.daemon_start()

        for d in tree.devices:
            remote = self.client.device_by_uid(d.unique_id)
            self.assertDeviceEqual(d, remote)

        self.assertEqual(self.client.generation, 4)

    def test_device_generation_update(self):
        """Generations of stored devices are updated once the devices show up.

        Host and dock are stored with generation 0. When they are connected
        with generations 4 and 3, the client's Generation must follow, and
        the new values must be written back to the store.
        """
        dock = TbDevice('Dock', gen=0)
        host = TbHost([dock], gen=0)
        tree = TbDomain(security=TbDomain.SECURITY_SECURE, host=host)

        for dev in (host, dock):
            self.store_put_device(dev)
        self.daemon_start()

        self.assertEqual(self.client.generation, 0)

        remote_host = self.client.device_by_uid(host.unique_id)
        self.assertEqual(remote_host.device_type, BoltDevice.HOST)

        with self.client.record() as tape:
            # bump the generations before the tree is connected
            host.generation, dock.generation = 4, 3
            tree.connect_tree(self.testbed)
            tape.wait_for_props(Generation=4)

        self.daemon_stop()

        # ensure that the updates are written to the store
        for dev, expected in ((host, 4), (dock, 3)):
            stored = self.store_get_device(dev.unique_id)
            self.assertEqual(stored.generation, expected)

    def test_device_linkspeed(self):
        """The LinkSpeed property mirrors the device's link attributes.

        Checks the initial value and that a change of the lane counts is
        reported as a LinkSpeed property change.
        """
        dock = TbDevice('Dock', gen=1)
        host = TbHost([dock], gen=0)
        tree = TbDomain(security=TbDomain.SECURITY_SECURE, host=host)

        initial = {'rx.speed': 20, 'rx.lanes': 1, 'tx.speed': 10, 'tx.lanes': 2}
        dock.linkspeed = initial

        tree.connect_tree(self.testbed)

        self.daemon_start()
        remote = self.client.device_by_uid(dock.unique_id)

        self.assertEqual(remote.linkspeed, initial)

        with remote.record() as tape:
            # swap the rx/tx lane counts and wait for the update
            updated = {'rx.speed': 20, 'rx.lanes': 2, 'tx.speed': 10, 'tx.lanes': 1}
            dock.linkspeed = updated
            tape.wait_for_props(LinkSpeed=updated)

        self.daemon_stop()

    def test_sdnotify(self):
        """Check that the daemon emits a STATUS notification via sd_notify.

        Starts the daemon with sd_notify support enabled and waits for a
        message matching 'STATUS' on the notification listener.
        """
        self.add_domain_host(security='secure', iommu='1')

        self.daemon_start(sdnotify=True)
        self.polkitd_start()

        # use unittest assertions rather than a bare ``assert``: the latter
        # is stripped when Python runs with -O and gives no failure message
        self.assertTrue(self.sdnotify)

        # NOTE(review): wait_for presumably returns None when no matching
        # message arrives (e.g. on timeout) -- confirm in the helper
        msgs = self.sdnotify.wait_for('STATUS')
        self.assertIsNotNone(msgs)

        self.daemon_stop()

    def test_boltctl(self):
        """Exercise the boltctl command line tool end to end.

        Builds a secure domain whose host has three cable -> SSD chains,
        starts the daemon and the polkit mock, and runs these boltctl
        sub-commands: --version, power, domains, list, authorize, enroll,
        info, forget (single and --all), chained enroll/authorize, and
        config. Each step checks the exit code and, where relevant, the
        printed output and the resulting device state as seen over D-Bus.
        """
        # three independent chains: host -> cableN -> ssdN
        ssd1 = TbDevice('SSD1',)
        cable1 = TbDevice('Cable1', children=[ssd1])
        ssd2 = TbDevice('SSD2')
        cable2 = TbDevice('Cable2', children=[ssd2])
        ssd3 = TbDevice('SSD3')
        cable3 = TbDevice('Cable3', children=[ssd3])

        host = TbHost([cable1,
                       cable2,
                       cable3])
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        host=host)

        tree.connect_tree(self.testbed)

        self.daemon_start()
        self.polkitd_start()

        # NOTE(review): out is converted with str(); if boltctl() returns
        # bytes, the substring checks below run against the bytes repr
        # (b'...'). They still pass, but confirm the intended type.
        out, _, res = self.boltctl('--version')
        out = str(out)
        self.assertEqual(res, 0)
        self.assertNotEqual(len(out), 0)
        self.assertIn('bolt', out)

        out, _, res = self.boltctl('power', '-q')
        out = str(out)
        self.assertEqual(res, 0)
        self.assertNotEqual(len(out), 0)
        self.assertIn('supported', out)

        # the domain's unique id must appear in the domain listing
        out, _, res = self.boltctl('domains')
        out = str(out)
        self.assertEqual(res, 0)
        self.assertNotEqual(len(out), 0)
        self.assertIn(tree.unique_id, out)

        # grant the polkit actions used by the commands below
        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize',
                                 'org.freedesktop.bolt.enroll',
                                 'org.freedesktop.bolt.manage'])

        # the single-device commands operate on the first chain only
        to_enroll = [cable1, ssd1]

        out, _, res = self.boltctl('list')
        out = str(out)
        self.assertEqual(res, 0)
        for dev in to_enroll:
            self.assertIn(dev.unique_id, out)

        # authorize: the device becomes authorized but is not stored
        for dev in to_enroll:
            uid = dev.unique_id
            _, _, res = self.boltctl('authorize', uid)
            self.assertEqual(res, 0)

            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, False)

        # enroll: the device stays authorized, is now stored, and
        # 'info' reports it as authorized
        for dev in to_enroll:
            uid = dev.unique_id
            _, _, res = self.boltctl('enroll', uid)
            self.assertEqual(res, 0)

            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, True)

            out, _, res = self.boltctl('info', uid)
            out = str(out)
            self.assertEqual(res, 0)
            self.assertIn('authorized', out)

        # forget: the device is removed from the store
        for dev in to_enroll:
            uid = dev.unique_id
            _, _, res = self.boltctl('forget', uid)
            self.assertEqual(res, 0)

            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.stored, False)

        # enroll again so that 'forget --all' has something to remove
        for dev in to_enroll:
            uid = dev.unique_id
            _, _, res = self.boltctl('enroll', uid)
            self.assertEqual(res, 0)

            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, True)

        _, _, res = self.boltctl('forget', '--all')
        self.assertEqual(res, 0)
        for dev in to_enroll:
            uid = dev.unique_id
            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.stored, False)

        # chain enrollment: enrolling ssd2 also enrolls its parent cable2
        _, _, res = self.boltctl('enroll', '--chain', ssd2.unique_id)
        self.assertEqual(res, 0)
        for dev in [cable2, ssd2]:
            uid = dev.unique_id
            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, True)

        # chain authorization: ssd3 and its parent cable3 get authorized
        # without being stored
        _, _, res = self.boltctl('authorize', '--chain', ssd3.unique_id)
        self.assertEqual(res, 0)
        for dev in [cable3, ssd3]:
            uid = dev.unique_id
            remote = self.client.device_by_uid(uid)
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, False)

        # boltctl config: describe properties
        out, _, res = self.boltctl('config', '--describe')
        out = str(out)
        self.assertEqual(res, 0)
        self.assertNotEqual(len(out), 0)
        self.assertIn('auth-mode', out)
        self.assertIn('domain.bootacl', out)
        self.assertIn('device.label', out)

        # boltctl config: read the current auth-mode
        out, _, res = self.boltctl('config', 'auth-mode')
        out = str(out)
        self.assertEqual(res, 0)
        self.assertNotEqual(len(out), 0)
        self.assertIn('enabled', out)

        # boltctl config: set auth-mode and wait for the D-Bus property
        with self.client.record() as tape:
            _, _, res = self.boltctl('config', 'auth-mode', 'disabled')
            self.assertEqual(res, 0)
            tape.wait_for_props(AuthMode='disabled')
        self.assertEqual(self.client.auth_mode, 'disabled')

        # boltctl config: set device label
        target = ssd2.unique_id
        remote = self.client.device_by_uid(target)
        with remote.record() as tape:
            _, _, res = self.boltctl('config', 'device.label', target, 'Nobody')
            self.assertEqual(res, 0)
            tape.wait_for_props(Label='Nobody')
        self.assertEqual(remote.label, 'Nobody')

        # all done
        self.daemon_stop()


def list_tests():
    """Yield ``(machine, human)`` name pairs for every test in BoltTest.

    ``machine`` is the test id without its first dotted component (e.g.
    ``BoltTest.test_foo``). ``human`` is the method name with its first
    ``len('test_')`` characters removed (e.g. ``foo``); this assumes every
    test method name starts with ``test_``.
    """
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(BoltTest)
    for case in suite:
        _, *rest = case.id().split('.')
        yield '.'.join(rest), rest[1][len('test_'):]


def main():
    """Script entry point.

    With the single argument ``list-tests``, print one ``machine human``
    line per test and exit. Otherwise, if umockdev is not in LD_PRELOAD,
    replace this process with ``umockdev-wrapper`` running the same
    command line. Once umockdev is preloaded, run the unittest suite.
    """
    argv = sys.argv
    if len(argv) == 2 and argv[1] == "list-tests":
        for machine, human in list_tests():
            print("%s %s" % (machine, human))
        sys.exit(0)

    preload = os.environ.get('LD_PRELOAD', '')
    if 'umockdev' not in preload:
        # re-exec ourselves through the wrapper; execvp does not return
        cmd = ['umockdev-wrapper', *argv]
        os.execvp(cmd[0], cmd)

    unittest.main(verbosity=2)


# Only run when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
