"""External dependencies loader.""" import contextlib import importlib from pathlib import Path import sys import logging import sysconfig from types import ModuleType from typing import Iterator, Iterable import zipfile _my_dir = Path(__file__).parent _log = logging.getLogger(__name__) _env_folder = Path(__file__).parent.joinpath("venv") def load_wheel(module_name: str, submodules: Iterable[str]) -> list[ModuleType]: """Loads modules from a wheel file 'module_name*.whl'. Loads `module_name`, and if submodules are given, loads `module_name.submodule` for each of the submodules. This allows loading all required modules from the same wheel in one session, ensuring that inter-submodule references are correct. Returns the loaded modules, so [module, submodule, submodule, ...]. """ fname_prefix = _fname_prefix_from_module_name(module_name) wheel = _wheel_filename(fname_prefix) loaded_modules: list[ModuleType] = [] to_load = [module_name] + [f"{module_name}.{submodule}" for submodule in submodules] # Load the module from the wheel file. Keep a backup of sys.path so that it # can be restored later. This should ensure that future import statements # cannot find this wheel file, increasing the separation of dependencies of # this add-on from other add-ons. with _sys_path_mod_backup(wheel): for modname in to_load: try: module = importlib.import_module(modname) except ImportError as ex: raise ImportError( "Unable to load %r from %s: %s" % (modname, wheel, ex) ) from None assert isinstance(module, ModuleType) loaded_modules.append(module) _log.info("Loaded %s from %s", modname, module.__file__) assert len(loaded_modules) == len( to_load ), f"expecting to load {len(to_load)} modules, but only have {len(loaded_modules)}: {loaded_modules}" return loaded_modules def load_wheel_global(module_name: str, fname_prefix: str = "", match_platform: bool = False) -> ModuleType: """Loads a wheel from 'fname_prefix*.whl', unless the named module can be imported. This allows us to use system-installed packages before falling back to the shipped wheels. This is useful for development, less so for deployment. If `fname_prefix` is the empty string, it will use the first package from `module_name`. In other words, `module_name="pkg.subpkg"` will result in `fname_prefix="pkg"`. """ if not fname_prefix: fname_prefix = _fname_prefix_from_module_name(module_name) try: module = importlib.import_module(module_name) except ImportError as ex: _log.debug("Unable to import %s directly, will try wheel: %s", module_name, ex) else: _log.debug( "Was able to load %s from %s, no need to load wheel %s", module_name, module.__file__, fname_prefix, ) return module wheel = _wheel_filename(fname_prefix, match_platform=match_platform) wheel_filepath = str(wheel) wheel_archive = zipfile.ZipFile(wheel_filepath) wheel_archive.extractall(_env_folder) if str(_env_folder) not in sys.path: sys.path.insert(0, str(_env_folder)) try: module = importlib.import_module(module_name) except ImportError as ex: raise ImportError( "Unable to load %r from %s: %s" % (module_name, wheel, ex) ) from None _log.debug("Globally loaded %s from %s", module_name, module.__file__) return module @contextlib.contextmanager def _sys_path_mod_backup(wheel_file: Path) -> Iterator[None]: """Temporarily inserts a wheel onto sys.path. When the context exits, it restores sys.path and sys.modules, so that anything that was imported within the context remains unimportable by other modules. """ old_syspath = sys.path[:] old_sysmod = sys.modules.copy() try: sys.path.insert(0, str(wheel_file)) yield finally: # Restore without assigning a new list instance. That way references # held by other code will stay valid. sys.path[:] = old_syspath sys.modules.clear() sys.modules.update(old_sysmod) def _wheel_filename(fname_prefix: str, match_platform: bool = False) -> Path: if match_platform: platform_tag = sysconfig.get_platform().replace('-','_').replace('.','_') path_pattern = f"{fname_prefix}*{platform_tag}.whl" else: path_pattern = f"{fname_prefix}*.whl" wheels: list[Path] = list(_my_dir.glob(path_pattern)) if not wheels: raise RuntimeError("Unable to find wheel at %r" % path_pattern) # If there are multiple wheels that match, load the last-modified one. # Alphabetical sorting isn't going to cut it since BAT 1.10 was released. def modtime(filepath: Path) -> float: return filepath.stat().st_mtime wheels.sort(key=modtime) return wheels[-1] def _fname_prefix_from_module_name(module_name: str) -> str: return module_name.split(".", 1)[0]