Skip to content
Snippets Groups Projects
Commit c0c08799 authored by Todd Gamblin's avatar Todd Gamblin
Browse files

Better extension activation/deactivation

parent 82dc935a
No related branches found
No related tags found
No related merge requests found
...@@ -89,10 +89,10 @@ def extensions(parser, args): ...@@ -89,10 +89,10 @@ def extensions(parser, args):
spack.cmd.find.display_specs(installed, mode=args.mode) spack.cmd.find.display_specs(installed, mode=args.mode)
# List specs of activated extensions. # List specs of activated extensions.
activated = spack.install_layout.get_extensions(spec) activated = spack.install_layout.extension_map(spec)
print print
if not activated: if not activated:
tty.msg("None activated.") tty.msg("None activated.")
return return
tty.msg("%d currently activated:" % len(activated)) tty.msg("%d currently activated:" % len(activated))
spack.cmd.find.display_specs(activated, mode=args.mode) spack.cmd.find.display_specs(activated.values(), mode=args.mode)
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
import exceptions import exceptions
import hashlib import hashlib
import shutil import shutil
import tempfile
from contextlib import closing from contextlib import closing
import llnl.util.tty as tty import llnl.util.tty as tty
...@@ -84,17 +85,38 @@ def make_path_for_spec(self, spec): ...@@ -84,17 +85,38 @@ def make_path_for_spec(self, spec):
raise NotImplementedError() raise NotImplementedError()
def get_extensions(self, spec): def extension_map(self, spec):
"""Get a set of currently installed extension packages for a spec.""" """Get a dict of currently installed extension packages for a spec.
Dict maps { name : extension_spec }
Modifying dict does not affect internals of this layout.
"""
raise NotImplementedError()
def check_extension_conflict(self, spec, ext_spec):
"""Ensure that ext_spec can be activated in spec.
If not, raise ExtensionAlreadyInstalledError or
ExtensionConflictError.
"""
raise NotImplementedError()
def check_activated(self, spec, ext_spec):
"""Ensure that ext_spec can be removed from spec.
If not, raise NoSuchExtensionError.
"""
raise NotImplementedError() raise NotImplementedError()
def add_extension(self, spec, extension_spec): def add_extension(self, spec, ext_spec):
"""Add to the list of currently installed extensions.""" """Add to the list of currently installed extensions."""
raise NotImplementedError() raise NotImplementedError()
def remove_extension(self, spec, extension_spec): def remove_extension(self, spec, ext_spec):
"""Remove from the list of currently installed extensions.""" """Remove from the list of currently installed extensions."""
raise NotImplementedError() raise NotImplementedError()
...@@ -173,6 +195,8 @@ def __init__(self, root, **kwargs): ...@@ -173,6 +195,8 @@ def __init__(self, root, **kwargs):
self.spec_file_name = spec_file_name self.spec_file_name = spec_file_name
self.extension_file_name = extension_file_name self.extension_file_name = extension_file_name
# Cache of already written/read extension maps.
self._extension_maps = {}
@property @property
def hidden_file_paths(self): def hidden_file_paths(self):
...@@ -271,54 +295,94 @@ def extension_file_path(self, spec): ...@@ -271,54 +295,94 @@ def extension_file_path(self, spec):
return join_path(self.path_for_spec(spec), self.extension_file_name) return join_path(self.path_for_spec(spec), self.extension_file_name)
def get_extensions(self, spec): def _extension_map(self, spec):
"""Get a dict<name -> spec> for all extensions currnetly
installed for this package."""
_check_concrete(spec) _check_concrete(spec)
extensions = set() if not spec in self._extension_maps:
path = self.extension_file_path(spec) path = self.extension_file_path(spec)
if os.path.exists(path): if not os.path.exists(path):
with closing(open(path)) as ext_file: self._extension_maps[spec] = {}
for line in ext_file:
try: else:
extensions.add(Spec(line.strip())) exts = {}
except spack.error.SpackError, e: with closing(open(path)) as ext_file:
raise InvalidExtensionSpecError(str(e)) for line in ext_file:
return extensions try:
spec = Spec(line.strip())
exts[spec.name] = spec
except spack.error.SpackError, e:
# TODO: do something better here -- should be
# resilient to corrupt files.
raise InvalidExtensionSpecError(str(e))
self._extension_maps[spec] = exts
return self._extension_maps[spec]
def extension_map(self, spec):
"""Defensive copying version of _extension_map() for external API."""
return self._extension_map(spec).copy()
def check_extension_conflict(self, spec, ext_spec):
exts = self._extension_map(spec)
if ext_spec.name in exts:
installed_spec = exts[ext_spec.name]
if ext_spec == installed_spec:
raise ExtensionAlreadyInstalledError(spec, ext_spec)
else:
raise ExtensionConflictError(spec, ext_spec, installed_spec)
def check_activated(self, spec, ext_spec):
exts = self._extension_map(spec)
if (not ext_spec.name in exts) or (ext_spec != exts[ext_spec.name]):
raise NoSuchExtensionError(spec, ext_spec)
def write_extensions(self, spec, extensions):
def _write_extensions(self, spec, extensions):
path = self.extension_file_path(spec) path = self.extension_file_path(spec)
with closing(open(path, 'w')) as spec_file:
for extension in sorted(extensions): # Create a temp file in the same directory as the actual file.
spec_file.write("%s\n" % extension) dirname, basename = os.path.split(path)
tmp = tempfile.NamedTemporaryFile(
prefix=basename, dir=dirname, delete=False)
# Write temp file.
with closing(tmp):
for extension in sorted(extensions.values()):
tmp.write("%s\n" % extension)
# Atomic update by moving tmpfile on top of old one.
os.rename(tmp.name, path)
def add_extension(self, spec, extension_spec): def add_extension(self, spec, ext_spec):
_check_concrete(spec) _check_concrete(spec)
_check_concrete(extension_spec) _check_concrete(ext_spec)
exts = self.get_extensions(spec) # Check whether it's already installed or if it's a conflict.
if extension_spec in exts: exts = self.extension_map(spec)
raise ExtensionAlreadyInstalledError(spec, extension_spec) self.check_extension_conflict(spec, ext_spec)
else:
for already_installed in exts:
if spec.name == extension_spec.name:
raise ExtensionConflictError(spec, extension_spec, already_installed)
exts.add(extension_spec) # do the actual adding.
self.write_extensions(spec, exts) exts[ext_spec.name] = ext_spec
self._write_extensions(spec, exts)
def remove_extension(self, spec, extension_spec): def remove_extension(self, spec, ext_spec):
_check_concrete(spec) _check_concrete(spec)
_check_concrete(extension_spec) _check_concrete(ext_spec)
exts = self.get_extensions(spec) # Make sure it's installed before removing.
if not extension_spec in exts: exts = self.extension_map(spec)
raise NoSuchExtensionError(spec, extension_spec) self.check_activated(spec, ext_spec)
exts.remove(extension_spec) # do the actual removing.
self.write_extensions(spec, exts) del exts[ext_spec.name]
self._write_extensions(spec, exts)
class DirectoryLayoutError(SpackError): class DirectoryLayoutError(SpackError):
...@@ -365,24 +429,24 @@ def __init__(self, message): ...@@ -365,24 +429,24 @@ def __init__(self, message):
class ExtensionAlreadyInstalledError(DirectoryLayoutError): class ExtensionAlreadyInstalledError(DirectoryLayoutError):
"""Raised when an extension is added to a package that already has it.""" """Raised when an extension is added to a package that already has it."""
def __init__(self, spec, extension_spec): def __init__(self, spec, ext_spec):
super(ExtensionAlreadyInstalledError, self).__init__( super(ExtensionAlreadyInstalledError, self).__init__(
"%s is already installed in %s" % (extension_spec.short_spec, spec.short_spec)) "%s is already installed in %s" % (ext_spec.short_spec, spec.short_spec))
class ExtensionConflictError(DirectoryLayoutError): class ExtensionConflictError(DirectoryLayoutError):
"""Raised when an extension is added to a package that already has it.""" """Raised when an extension is added to a package that already has it."""
def __init__(self, spec, extension_spec, conflict): def __init__(self, spec, ext_spec, conflict):
super(ExtensionConflictError, self).__init__( super(ExtensionConflictError, self).__init__(
"%s cannot be installed in %s because it conflicts with %s."% ( "%s cannot be installed in %s because it conflicts with %s."% (
extension_spec.short_spec, spec.short_spec, conflict.short_spec)) ext_spec.short_spec, spec.short_spec, conflict.short_spec))
class NoSuchExtensionError(DirectoryLayoutError): class NoSuchExtensionError(DirectoryLayoutError):
"""Raised when an extension isn't there on remove.""" """Raised when an extension isn't there on remove."""
def __init__(self, spec, extension_spec): def __init__(self, spec, ext_spec):
super(NoSuchExtensionError, self).__init__( super(NoSuchExtensionError, self).__init__(
"%s cannot be removed from %s because it's not installed."% ( "%s cannot be removed from %s because it's not installed."% (
extension_spec.short_spec, spec.short_spec)) ext_spec.short_spec, spec.short_spec))
...@@ -534,7 +534,8 @@ def activated(self): ...@@ -534,7 +534,8 @@ def activated(self):
if not self.is_extension: if not self.is_extension:
raise ValueError("is_extension called on package that is not an extension.") raise ValueError("is_extension called on package that is not an extension.")
return self.spec in spack.install_layout.get_extensions(self.extendee_spec) exts = spack.install_layout.extension_map(self.extendee_spec)
return (self.name in exts) and (exts[self.name] == self.spec)
def preorder_traversal(self, visited=None, **kwargs): def preorder_traversal(self, visited=None, **kwargs):
...@@ -987,6 +988,8 @@ def do_activate(self): ...@@ -987,6 +988,8 @@ def do_activate(self):
activate() directly. activate() directly.
""" """
self._sanity_check_extension() self._sanity_check_extension()
spack.install_layout.check_extension_conflict(self.extendee_spec, self.spec)
self.extendee_spec.package.activate(self, **self.extendee_args) self.extendee_spec.package.activate(self, **self.extendee_args)
spack.install_layout.add_extension(self.extendee_spec, self.spec) spack.install_layout.add_extension(self.extendee_spec, self.spec)
...@@ -1014,12 +1017,22 @@ def ignore(filename): ...@@ -1014,12 +1017,22 @@ def ignore(filename):
tree.merge(self.prefix, ignore=ignore) tree.merge(self.prefix, ignore=ignore)
def do_deactivate(self): def do_deactivate(self, **kwargs):
"""Called on the extension to invoke extendee's deactivate() method.""" """Called on the extension to invoke extendee's deactivate() method."""
force = kwargs.get('force', False)
self._sanity_check_extension() self._sanity_check_extension()
# Allow a force deactivate to happen. This can unlink
# spurious files if something was corrupted.
if not force:
spack.install_layout.check_activated(self.extendee_spec, self.spec)
self.extendee_spec.package.deactivate(self, **self.extendee_args) self.extendee_spec.package.deactivate(self, **self.extendee_args)
if self.spec in spack.install_layout.get_extensions(self.extendee_spec): # redundant activation check -- makes SURE the spec is not
# still activated even if something was wrong above.
if self.activated:
spack.install_layout.remove_extension(self.extendee_spec, self.spec) spack.install_layout.remove_extension(self.extendee_spec, self.spec)
tty.msg("Deactivated extension %s for %s." tty.msg("Deactivated extension %s for %s."
......
...@@ -98,9 +98,9 @@ def ignore(filename): ...@@ -98,9 +98,9 @@ def ignore(filename):
return ignore return ignore
def write_easy_install_pth(self, extensions): def write_easy_install_pth(self, exts):
paths = [] paths = []
for ext in extensions: for ext in sorted(exts.values()):
ext_site_packages = os.path.join(ext.prefix, self.site_packages_dir) ext_site_packages = os.path.join(ext.prefix, self.site_packages_dir)
easy_pth = "%s/easy-install.pth" % ext_site_packages easy_pth = "%s/easy-install.pth" % ext_site_packages
...@@ -139,15 +139,15 @@ def activate(self, ext_pkg, **args): ...@@ -139,15 +139,15 @@ def activate(self, ext_pkg, **args):
args.update(ignore=self.python_ignore(ext_pkg, args)) args.update(ignore=self.python_ignore(ext_pkg, args))
super(Python, self).activate(ext_pkg, **args) super(Python, self).activate(ext_pkg, **args)
extensions = set(spack.install_layout.get_extensions(self.spec)) exts = spack.install_layout.extension_map(self.spec)
extensions.add(ext_pkg.spec) exts[ext_pkg.name] = ext_pkg.spec
self.write_easy_install_pth(extensions) self.write_easy_install_pth(exts)
def deactivate(self, ext_pkg, **args): def deactivate(self, ext_pkg, **args):
args.update(ignore=self.python_ignore(ext_pkg, args)) args.update(ignore=self.python_ignore(ext_pkg, args))
super(Python, self).deactivate(ext_pkg, **args) super(Python, self).deactivate(ext_pkg, **args)
extensions = set(spack.install_layout.get_extensions(self.spec)) exts = spack.install_layout.extension_map(self.spec)
extensions.remove(ext_pkg.spec) del exts[ext_pkg.name]
self.write_easy_install_pth(extensions) self.write_easy_install_pth(exts)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment