diff options
Diffstat (limited to 'wlauto/core/extension_loader.py')
-rw-r--r-- | wlauto/core/extension_loader.py | 400 |
1 files changed, 400 insertions, 0 deletions
diff --git a/wlauto/core/extension_loader.py b/wlauto/core/extension_loader.py new file mode 100644 index 00000000..0263f830 --- /dev/null +++ b/wlauto/core/extension_loader.py @@ -0,0 +1,400 @@ +# Copyright 2013-2015 ARM Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import os +import sys +import inspect +import imp +import string +import logging +from functools import partial +from collections import OrderedDict + +from wlauto.core.bootstrap import settings +from wlauto.core.extension import Extension +from wlauto.exceptions import NotFoundError, LoaderError +from wlauto.utils.misc import walk_modules, load_class, merge_lists, merge_dicts, get_article +from wlauto.utils.types import identifier + + +MODNAME_TRANS = string.maketrans(':/\\.', '____') + + +class ExtensionLoaderItem(object): + + def __init__(self, ext_tuple): + self.name = ext_tuple.name + self.default_package = ext_tuple.default_package + self.default_path = ext_tuple.default_path + self.cls = load_class(ext_tuple.cls) + + +class GlobalParameterAlias(object): + """ + Represents a "global alias" for an extension parameter. A global alias + is specified at the top-level of config rather namespaced under an extension + name. + + Multiple extensions may have parameters with the same global_alias if they are + part of the same inheritance hierarchy and one parameter is an override of the + other. This class keeps track of all such cases in its extensions dict. + + """ + + def __init__(self, name): + self.name = name + self.extensions = {} + + def iteritems(self): + for ext in self.extensions.itervalues(): + yield (self.get_param(ext), ext) + + def get_param(self, ext): + for param in ext.parameters: + if param.global_alias == self.name: + return param + message = 'Extension {} does not have a parameter with global alias {}' + raise ValueError(message.format(ext.name, self.name)) + + def update(self, other_ext): + self._validate_ext(other_ext) + self.extensions[other_ext.name] = other_ext + + def _validate_ext(self, other_ext): + other_param = self.get_param(other_ext) + for param, ext in self.iteritems(): + if ((not (issubclass(ext, other_ext) or issubclass(other_ext, ext))) and + other_param.kind != param.kind): + message = 'Duplicate global alias {} declared in {} and {} extensions with different types' + raise LoaderError(message.format(self.name, ext.name, other_ext.name)) + if not param.name == other_param.name: + message = 'Two params {} in {} and {} in {} both declare global alias {}' + raise LoaderError(message.format(param.name, ext.name, + other_param.name, other_ext.name, self.name)) + + def __str__(self): + text = 'GlobalAlias({} => {})' + extlist = ', '.join(['{}.{}'.format(e.name, p.name) for p, e in self.iteritems()]) + return text.format(self.name, extlist) + + +class ExtensionLoader(object): + """ + Discovers, enumerates and loads available devices, configs, etc. + The loader will attempt to discover things on construction by looking + in predetermined set of locations defined by default_paths. Optionally, + additional locations may specified through paths parameter that must + be a list of additional Python module paths (i.e. dot-delimited). + + """ + + _instance = None + + # Singleton + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(ExtensionLoader, cls).__new__(cls, *args, **kwargs) + else: + for k, v in kwargs.iteritems(): + if not hasattr(cls._instance, k): + raise ValueError('Invalid parameter for ExtensionLoader: {}'.format(k)) + setattr(cls._instance, k, v) + return cls._instance + + def set_load_defaults(self, value): + self._load_defaults = value + if value: + self.packages = merge_lists(self.default_packages, self.packages, duplicates='last') + + def get_load_defaults(self): + return self._load_defaults + + load_defaults = property(get_load_defaults, set_load_defaults) + + def __init__(self, packages=None, paths=None, ignore_paths=None, keep_going=False, load_defaults=True): + """ + params:: + + :packages: List of packages to load extensions from. + :paths: List of paths to be searched for Python modules containing + WA extensions. + :ignore_paths: List of paths to ignore when search for WA extensions (these would + typically be subdirectories of one or more locations listed in + ``paths`` parameter. + :keep_going: Specifies whether to keep going if an error occurs while loading + extensions. + :load_defaults: Specifies whether extension should be loaded from default locations + (WA package, and user's WA directory) as well as the packages/paths + specified explicitly in ``packages`` and ``paths`` parameters. + + """ + self._load_defaults = None + self.logger = logging.getLogger('ExtensionLoader') + self.keep_going = keep_going + self.extension_kinds = {ext_tuple.name: ExtensionLoaderItem(ext_tuple) + for ext_tuple in settings.extensions} + self.default_packages = [ext.default_package for ext in self.extension_kinds.values()] + + self.packages = packages or [] + self.load_defaults = load_defaults + self.paths = paths or [] + self.ignore_paths = ignore_paths or [] + self.extensions = {} + self.aliases = {} + self.global_param_aliases = {} + # create an empty dict for each extension type to store discovered + # extensions. + for ext in self.extension_kinds.values(): + setattr(self, '_' + ext.name, {}) + self._load_from_packages(self.packages) + self._load_from_paths(self.paths, self.ignore_paths) + + def update(self, packages=None, paths=None, ignore_paths=None): + """ Load extensions from the specified paths/packages + without clearing or reloading existing extension. """ + if packages: + self.packages.extend(packages) + self._load_from_packages(packages) + if paths: + self.paths.extend(paths) + self.ignore_paths.extend(ignore_paths or []) + self._load_from_paths(paths, ignore_paths or []) + + def clear(self): + """ Clear all discovered items. """ + self.extensions.clear() + for ext in self.extension_kinds.values(): + self._get_store(ext).clear() + + def reload(self): + """ Clear all discovered items and re-run the discovery. """ + self.clear() + self._load_from_packages(self.packages) + self._load_from_paths(self.paths, self.ignore_paths) + + def get_extension_class(self, name, kind=None): + """ + Return the class for the specified extension if found or raises ``ValueError``. + + """ + name, _ = self.resolve_alias(name) + if kind is None: + return self.extensions[name] + ext = self.extension_kinds.get(kind) + if ext is None: + raise ValueError('Unknown extension type: {}'.format(kind)) + store = self._get_store(ext) + if name not in store: + raise NotFoundError('Extensions {} is not {} {}.'.format(name, get_article(kind), kind)) + return store[name] + + def get_extension(self, name, *args, **kwargs): + """ + Return extension of the specified kind with the specified name. Any additional + parameters will be passed to the extension's __init__. + + """ + name, base_kwargs = self.resolve_alias(name) + kind = kwargs.pop('kind', None) + kwargs = merge_dicts(base_kwargs, kwargs, list_duplicates='last', dict_type=OrderedDict) + cls = self.get_extension_class(name, kind) + extension = _instantiate(cls, args, kwargs) + extension.load_modules(self) + return extension + + def get_default_config(self, ext_name): + """ + Returns the default configuration for the specified extension name. The name may be an alias, + in which case, the returned config will be augmented with appropriate alias overrides. + + """ + real_name, alias_config = self.resolve_alias(ext_name) + base_default_config = self.get_extension_class(real_name).get_default_config() + return merge_dicts(base_default_config, alias_config, list_duplicates='last', dict_type=OrderedDict) + + def list_extensions(self, kind=None): + """ + List discovered extension classes. Optionally, only list extensions of a + particular type. + + """ + if kind is None: + return self.extensions.values() + if kind not in self.extension_kinds: + raise ValueError('Unknown extension type: {}'.format(kind)) + return self._get_store(self.extension_kinds[kind]).values() + + def has_extension(self, name, kind=None): + """ + Returns ``True`` if an extensions with the specified ``name`` has been + discovered by the loader. If ``kind`` was specified, only returns ``True`` + if the extension has been found, *and* it is of the specified kind. + + """ + try: + self.get_extension_class(name, kind) + return True + except NotFoundError: + return False + + def resolve_alias(self, alias_name): + """ + Try to resolve the specified name as an extension alias. Returns a + two-tuple, the first value of which is actual extension name, and the + second is a dict of parameter values for this alias. If the name passed + is already an extension name, then the result is ``(alias_name, {})``. + + """ + alias_name = identifier(alias_name.lower()) + if alias_name in self.extensions: + return (alias_name, {}) + if alias_name in self.aliases: + alias = self.aliases[alias_name] + return (alias.extension_name, alias.params) + raise NotFoundError('Could not find extension or alias "{}"'.format(alias_name)) + + # Internal methods. + + def __getattr__(self, name): + """ + This resolves methods for specific extensions types based on corresponding + generic extension methods. So it's possible to say things like :: + + loader.get_device('foo') + + instead of :: + + loader.get_extension('foo', kind='device') + + """ + if name.startswith('get_'): + name = name.replace('get_', '', 1) + if name in self.extension_kinds: + return partial(self.get_extension, kind=name) + if name.startswith('list_'): + name = name.replace('list_', '', 1).rstrip('s') + if name in self.extension_kinds: + return partial(self.list_extensions, kind=name) + if name.startswith('has_'): + name = name.replace('has_', '', 1) + if name in self.extension_kinds: + return partial(self.has_extension, kind=name) + raise AttributeError(name) + + def _get_store(self, ext): + name = getattr(ext, 'name', ext) + return getattr(self, '_' + name) + + def _load_from_packages(self, packages): + try: + for package in packages: + for module in walk_modules(package): + self._load_module(module) + except ImportError as e: + message = 'Problem loading extensions from extra packages: {}' + raise LoaderError(message.format(e.message)) + + def _load_from_paths(self, paths, ignore_paths): + self.logger.debug('Loading from paths.') + for path in paths: + self.logger.debug('Checking path %s', path) + for root, _, files in os.walk(path): + should_skip = False + for igpath in ignore_paths: + if root.startswith(igpath): + should_skip = True + break + if should_skip: + continue + for fname in files: + if not os.path.splitext(fname)[1].lower() == '.py': + continue + filepath = os.path.join(root, fname) + try: + modname = os.path.splitext(filepath[1:])[0].translate(MODNAME_TRANS) + module = imp.load_source(modname, filepath) + self._load_module(module) + except (SystemExit, ImportError), e: + if self.keep_going: + self.logger.warn('Failed to load {}'.format(filepath)) + self.logger.warn('Got: {}'.format(e)) + else: + raise LoaderError('Failed to load {}'.format(filepath), sys.exc_info()) + + def _load_module(self, module): # NOQA pylint: disable=too-many-branches + self.logger.debug('Checking module %s', module.__name__) + for obj in vars(module).itervalues(): + if inspect.isclass(obj): + if not issubclass(obj, Extension) or not hasattr(obj, 'name') or not obj.name: + continue + try: + for ext in self.extension_kinds.values(): + if issubclass(obj, ext.cls): + self._add_found_extension(obj, ext) + break + else: # did not find a matching Extension type + message = 'Unknown extension type for {} (type: {})' + raise LoaderError(message.format(obj.name, obj.__class__.__name__)) + except LoaderError as e: + if self.keep_going: + self.logger.warning(e) + else: + raise e + + def _add_found_extension(self, obj, ext): + """ + :obj: Found extension class + :ext: matching extension item. + """ + self.logger.debug('\tAdding %s %s', ext.name, obj.name) + key = identifier(obj.name.lower()) + obj.kind = ext.name + if key in self.extensions or key in self.aliases: + raise LoaderError('{} {} already exists.'.format(ext.name, obj.name)) + # Extensions are tracked both, in a common extensions + # dict, and in per-extension kind dict (as retrieving + # extensions by kind is a common use case. + self.extensions[key] = obj + store = self._get_store(ext) + store[key] = obj + for alias in obj.aliases: + if alias in self.extensions or alias in self.aliases: + raise LoaderError('{} {} already exists.'.format(ext.name, obj.name)) + self.aliases[alias.name] = alias + + # Update global aliases list. If a global alias is already in the list, + # then make sure this extension is in the same parent/child hierarchy + # as the one already found. + for param in obj.parameters: + if param.global_alias: + if param.global_alias not in self.global_param_aliases: + ga = GlobalParameterAlias(param.global_alias) + ga.update(obj) + self.global_param_aliases[ga.name] = ga + else: # global alias already exists. + self.global_param_aliases[param.global_alias].update(obj) + + +# Utility functions. + +def _instantiate(cls, args=None, kwargs=None): + args = [] if args is None else args + kwargs = {} if kwargs is None else kwargs + try: + return cls(*args, **kwargs) + except Exception: + raise LoaderError('Could not load {}'.format(cls), sys.exc_info()) + |