From eb8fc4f3be7a1b5dd90313f0a8381bdaca346632 Mon Sep 17 00:00:00 2001
From: Todd Gamblin <tgamblin@llnl.gov>
Date: Sat, 21 Dec 2019 16:23:54 -0800
Subject: [PATCH] lock transactions: fix non-transactional writes

Lock transactions were actually writing *after* the lock was
released. The code was looking at the result of `release_write()` before
writing, then writing based on whether the lock was released.  This is
pretty obviously wrong.

- [x] Refactor `Lock` so that a release function can be passed to the
      `Lock` and called *only* when a lock is really released.

- [x] Refactor `LockTransaction` classes to use the release function
  instead of checking the return value of `release_read()` / `release_write()`
---
 lib/spack/llnl/util/lock.py            | 100 +++++---
 lib/spack/spack/database.py            |   7 +-
 lib/spack/spack/test/llnl/util/lock.py | 340 +++++++++++++++----------
 lib/spack/spack/util/file_cache.py     |  10 +-
 4 files changed, 285 insertions(+), 172 deletions(-)

diff --git a/lib/spack/llnl/util/lock.py b/lib/spack/llnl/util/lock.py
index 66cb067c88..c675c7c452 100644
--- a/lib/spack/llnl/util/lock.py
+++ b/lib/spack/llnl/util/lock.py
@@ -95,10 +95,6 @@ def _lock(self, op, timeout=None):
         The lock is implemented as a spin lock using a nonblocking call
         to ``lockf()``.
 
-        On acquiring an exclusive lock, the lock writes this process's
-        pid and host to the lock file, in case the holding process needs
-        to be killed later.
-
         If the lock times out, it raises a ``LockError``. If the lock is
         successfully acquired, the total wait time and the number of attempts
         is returned.
@@ -284,11 +280,19 @@ def acquire_write(self, timeout=None):
             self._writes += 1
             return False
 
-    def release_read(self):
+    def release_read(self, release_fn=None):
         """Releases a read lock.
 
-        Returns True if the last recursive lock was released, False if
-        there are still outstanding locks.
+        Arguments:
+            release_fn (callable): function to call *before* the last recursive
+                lock (read or write) is released.
+
+        If the last recursive lock will be released, then this will call
+        release_fn and return its result (if provided), or return True
+        (if release_fn was not provided).
+
+        Otherwise, we are still nested inside some other lock, so do not
+        call the release_fn, and return False.
 
         Does limited correctness checking: if a read lock is released
         when none are held, this will raise an assertion error.
@@ -300,18 +304,30 @@ def release_read(self):
             self._debug(
                 'READ LOCK: {0.path}[{0._start}:{0._length}] [Released]'
                 .format(self))
+
+            result = True
+            if release_fn is not None:
+                result = release_fn()
+
             self._unlock()      # can raise LockError.
             self._reads -= 1
-            return True
+            return result
         else:
             self._reads -= 1
             return False
 
-    def release_write(self):
+    def release_write(self, release_fn=None):
         """Releases a write lock.
 
-        Returns True if the last recursive lock was released, False if
-        there are still outstanding locks.
+        Arguments:
+            release_fn (callable): function to call before the last recursive
+                write lock is released.
+
+        If the last recursive *write* lock will be released, then this
+        will call release_fn and return its result (if provided), or
+        return True (if release_fn was not provided). Otherwise, we are
+        still nested inside some other write lock, so do not call the
+        release_fn, and return False.
 
         Does limited correctness checking: if a read lock is released
         when none are held, this will raise an assertion error.
@@ -323,9 +339,16 @@ def release_write(self):
             self._debug(
                 'WRITE LOCK: {0.path}[{0._start}:{0._length}] [Released]'
                 .format(self))
+
+            # we need to call release_fn before releasing the lock
+            result = True
+            if release_fn is not None:
+                result = release_fn()
+
             self._unlock()      # can raise LockError.
             self._writes -= 1
-            return True
+            return result
+
         else:
             self._writes -= 1
             return False
@@ -349,28 +372,36 @@ def _acquired_debug(self, lock_type, wait_time, nattempts):
 class LockTransaction(object):
     """Simple nested transaction context manager that uses a file lock.
 
-    This class can trigger actions when the lock is acquired for the
-    first time and released for the last.
+    Arguments:
+        lock (Lock): underlying lock for this transaction to be acquired on
+            enter and released on exit
+        acquire (callable or contextmanager): function to be called after lock
+            is acquired, or contextmanager to enter after acquire and leave
+            before release.
+        release (callable): function to be called before release. If
+            ``acquire`` is a contextmanager, this will be called *after*
+            exiting the nested context and before the lock is released.
+        timeout (float): number of seconds to set for the timeout when
+            acquiring the lock (default no timeout)
 
     If the ``acquire_fn`` returns a value, it is used as the return value for
     ``__enter__``, allowing it to be passed as the ``as`` argument of a
     ``with`` statement.
 
     If ``acquire_fn`` returns a context manager, *its* ``__enter__`` function
-    will be called in ``__enter__`` after ``acquire_fn``, and its ``__exit__``
-    funciton will be called before ``release_fn`` in ``__exit__``, allowing you
-    to nest a context manager to be used along with the lock.
+    will be called after the lock is acquired, and its ``__exit__`` function
+    will be called before ``release_fn`` in ``__exit__``, allowing you to
+    nest a context manager inside this one.
 
     Timeout for lock is customizable.
 
     """
 
-    def __init__(self, lock, acquire_fn=None, release_fn=None,
-                 timeout=None):
+    def __init__(self, lock, acquire=None, release=None, timeout=None):
         self._lock = lock
         self._timeout = timeout
-        self._acquire_fn = acquire_fn
-        self._release_fn = release_fn
+        self._acquire_fn = acquire
+        self._release_fn = release
         self._as = None
 
     def __enter__(self):
@@ -383,13 +414,18 @@ def __enter__(self):
 
     def __exit__(self, type, value, traceback):
         suppress = False
-        if self._exit():
-            if self._as and hasattr(self._as, '__exit__'):
-                if self._as.__exit__(type, value, traceback):
-                    suppress = True
-            if self._release_fn:
-                if self._release_fn(type, value, traceback):
-                    suppress = True
+
+        def release_fn():
+            if self._release_fn is not None:
+                return self._release_fn(type, value, traceback)
+
+        if self._as and hasattr(self._as, '__exit__'):
+            if self._as.__exit__(type, value, traceback):
+                suppress = True
+
+        if self._exit(release_fn):
+            suppress = True
+
         return suppress
 
 
@@ -398,8 +434,8 @@ class ReadTransaction(LockTransaction):
     def _enter(self):
         return self._lock.acquire_read(self._timeout)
 
-    def _exit(self):
-        return self._lock.release_read()
+    def _exit(self, release_fn):
+        return self._lock.release_read(release_fn)
 
 
 class WriteTransaction(LockTransaction):
@@ -407,8 +443,8 @@ class WriteTransaction(LockTransaction):
     def _enter(self):
         return self._lock.acquire_write(self._timeout)
 
-    def _exit(self):
-        return self._lock.release_write()
+    def _exit(self, release_fn):
+        return self._lock.release_write(release_fn)
 
 
 class LockError(Exception):
diff --git a/lib/spack/spack/database.py b/lib/spack/spack/database.py
index e6e82f9803..243f1a20d5 100644
--- a/lib/spack/spack/database.py
+++ b/lib/spack/spack/database.py
@@ -332,11 +332,12 @@ def __init__(self, root, db_dir=None, upstream_dbs=None,
 
     def write_transaction(self):
         """Get a write lock context manager for use in a `with` block."""
-        return WriteTransaction(self.lock, self._read, self._write)
+        return WriteTransaction(
+            self.lock, acquire=self._read, release=self._write)
 
     def read_transaction(self):
         """Get a read lock context manager for use in a `with` block."""
-        return ReadTransaction(self.lock, self._read)
+        return ReadTransaction(self.lock, acquire=self._read)
 
     def prefix_lock(self, spec):
         """Get a lock on a particular spec's installation directory.
@@ -624,7 +625,7 @@ def _read_suppress_error():
                 self._data = {}
 
         transaction = WriteTransaction(
-            self.lock, _read_suppress_error, self._write
+            self.lock, acquire=_read_suppress_error, release=self._write
         )
 
         with transaction:
diff --git a/lib/spack/spack/test/llnl/util/lock.py b/lib/spack/spack/test/llnl/util/lock.py
index d8081d108c..3bf8a236b1 100644
--- a/lib/spack/spack/test/llnl/util/lock.py
+++ b/lib/spack/spack/test/llnl/util/lock.py
@@ -42,6 +42,7 @@
 actually on a shared filesystem.
 
 """
+import collections
 import os
 import socket
 import shutil
@@ -776,189 +777,258 @@ def p3(barrier):
     multiproc_test(p1, p2, p3)
 
 
-def test_transaction(lock_path):
+class AssertLock(lk.Lock):
+    """Test lock class that marks acquire/release events."""
+    def __init__(self, lock_path, vals):
+        super(AssertLock, self).__init__(lock_path)
+        self.vals = vals
+
+    def acquire_read(self, timeout=None):
+        self.assert_acquire_read()
+        result = super(AssertLock, self).acquire_read(timeout)
+        self.vals['acquired_read'] = True
+        return result
+
+    def acquire_write(self, timeout=None):
+        self.assert_acquire_write()
+        result = super(AssertLock, self).acquire_write(timeout)
+        self.vals['acquired_write'] = True
+        return result
+
+    def release_read(self, release_fn=None):
+        self.assert_release_read()
+        result = super(AssertLock, self).release_read(release_fn)
+        self.vals['released_read'] = True
+        return result
+
+    def release_write(self, release_fn=None):
+        self.assert_release_write()
+        result = super(AssertLock, self).release_write(release_fn)
+        self.vals['released_write'] = True
+        return result
+
+
+@pytest.mark.parametrize(
+    "transaction,type",
+    [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
+)
+def test_transaction(lock_path, transaction, type):
+    class MockLock(AssertLock):
+        def assert_acquire_read(self):
+            assert not vals['entered_fn']
+            assert not vals['exited_fn']
+
+        def assert_release_read(self):
+            assert vals['entered_fn']
+            assert not vals['exited_fn']
+
+        def assert_acquire_write(self):
+            assert not vals['entered_fn']
+            assert not vals['exited_fn']
+
+        def assert_release_write(self):
+            assert vals['entered_fn']
+            assert not vals['exited_fn']
+
     def enter_fn():
-        vals['entered'] = True
+        # assert enter_fn is called while lock is held
+        assert vals['acquired_%s' % type]
+        vals['entered_fn'] = True
 
     def exit_fn(t, v, tb):
-        vals['exited'] = True
+        # assert exit_fn is called while lock is held
+        assert not vals['released_%s' % type]
+        vals['exited_fn'] = True
         vals['exception'] = (t or v or tb)
 
-    lock = lk.Lock(lock_path)
-    vals = {'entered': False, 'exited': False, 'exception': False}
-    with lk.ReadTransaction(lock, enter_fn, exit_fn):
-        pass
+    vals = collections.defaultdict(lambda: False)
+    lock = MockLock(lock_path, vals)
+
+    with transaction(lock, acquire=enter_fn, release=exit_fn):
+        assert vals['acquired_%s' % type]
+        assert not vals['released_%s' % type]
 
-    assert vals['entered']
-    assert vals['exited']
+    assert vals['entered_fn']
+    assert vals['exited_fn']
+    assert vals['acquired_%s' % type]
+    assert vals['released_%s' % type]
     assert not vals['exception']
 
-    vals = {'entered': False, 'exited': False, 'exception': False}
-    with lk.WriteTransaction(lock, enter_fn, exit_fn):
-        pass
 
-    assert vals['entered']
-    assert vals['exited']
-    assert not vals['exception']
+@pytest.mark.parametrize(
+    "transaction,type",
+    [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
+)
+def test_transaction_with_exception(lock_path, transaction, type):
+    class MockLock(AssertLock):
+        def assert_acquire_read(self):
+            assert not vals['entered_fn']
+            assert not vals['exited_fn']
+
+        def assert_release_read(self):
+            assert vals['entered_fn']
+            assert not vals['exited_fn']
 
+        def assert_acquire_write(self):
+            assert not vals['entered_fn']
+            assert not vals['exited_fn']
+
+        def assert_release_write(self):
+            assert vals['entered_fn']
+            assert not vals['exited_fn']
 
-def test_transaction_with_exception(lock_path):
     def enter_fn():
-        vals['entered'] = True
+        assert vals['acquired_%s' % type]
+        vals['entered_fn'] = True
 
     def exit_fn(t, v, tb):
-        vals['exited'] = True
+        assert not vals['released_%s' % type]
+        vals['exited_fn'] = True
         vals['exception'] = (t or v or tb)
+        return exit_result
 
-    lock = lk.Lock(lock_path)
-
-    def do_read_with_exception():
-        with lk.ReadTransaction(lock, enter_fn, exit_fn):
-            raise Exception()
-
-    def do_write_with_exception():
-        with lk.WriteTransaction(lock, enter_fn, exit_fn):
-            raise Exception()
+    exit_result = False
+    vals = collections.defaultdict(lambda: False)
+    lock = MockLock(lock_path, vals)
 
-    vals = {'entered': False, 'exited': False, 'exception': False}
     with pytest.raises(Exception):
-        do_read_with_exception()
-    assert vals['entered']
-    assert vals['exited']
-    assert vals['exception']
+        with transaction(lock, acquire=enter_fn, release=exit_fn):
+            raise Exception()
 
-    vals = {'entered': False, 'exited': False, 'exception': False}
-    with pytest.raises(Exception):
-        do_write_with_exception()
-    assert vals['entered']
-    assert vals['exited']
+    assert vals['entered_fn']
+    assert vals['exited_fn']
     assert vals['exception']
 
+    # test suppression of exceptions from exit_fn
+    exit_result = True
+    vals.clear()
 
-def test_transaction_with_context_manager(lock_path):
-    class TestContextManager(object):
-
-        def __enter__(self):
-            vals['entered'] = True
-
-        def __exit__(self, t, v, tb):
-            vals['exited'] = True
-            vals['exception'] = (t or v or tb)
-
-    def exit_fn(t, v, tb):
-        vals['exited_fn'] = True
-        vals['exception_fn'] = (t or v or tb)
-
-    lock = lk.Lock(lock_path)
-
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with lk.ReadTransaction(lock, TestContextManager, exit_fn):
-        pass
+    # should not raise now.
+    with transaction(lock, acquire=enter_fn, release=exit_fn):
+        raise Exception()
 
-    assert vals['entered']
-    assert vals['exited']
-    assert not vals['exception']
+    assert vals['entered_fn']
     assert vals['exited_fn']
-    assert not vals['exception_fn']
-
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with lk.ReadTransaction(lock, TestContextManager):
-        pass
-
-    assert vals['entered']
-    assert vals['exited']
-    assert not vals['exception']
-    assert not vals['exited_fn']
-    assert not vals['exception_fn']
+    assert vals['exception']
 
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with lk.WriteTransaction(lock, TestContextManager, exit_fn):
-        pass
 
-    assert vals['entered']
-    assert vals['exited']
-    assert not vals['exception']
-    assert vals['exited_fn']
-    assert not vals['exception_fn']
+@pytest.mark.parametrize(
+    "transaction,type",
+    [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
+)
+def test_transaction_with_context_manager(lock_path, transaction, type):
+    class MockLock(AssertLock):
+        def assert_acquire_read(self):
+            assert not vals['entered_ctx']
+            assert not vals['exited_ctx']
 
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with lk.WriteTransaction(lock, TestContextManager):
-        pass
+        def assert_release_read(self):
+            assert vals['entered_ctx']
+            assert vals['exited_ctx']
 
-    assert vals['entered']
-    assert vals['exited']
-    assert not vals['exception']
-    assert not vals['exited_fn']
-    assert not vals['exception_fn']
+        def assert_acquire_write(self):
+            assert not vals['entered_ctx']
+            assert not vals['exited_ctx']
 
+        def assert_release_write(self):
+            assert vals['entered_ctx']
+            assert vals['exited_ctx']
 
-def test_transaction_with_context_manager_and_exception(lock_path):
     class TestContextManager(object):
         def __enter__(self):
-            vals['entered'] = True
+            vals['entered_ctx'] = True
 
         def __exit__(self, t, v, tb):
-            vals['exited'] = True
-            vals['exception'] = (t or v or tb)
+            assert not vals['released_%s' % type]
+            vals['exited_ctx'] = True
+            vals['exception_ctx'] = (t or v or tb)
+            return exit_ctx_result
 
     def exit_fn(t, v, tb):
+        assert not vals['released_%s' % type]
         vals['exited_fn'] = True
         vals['exception_fn'] = (t or v or tb)
+        return exit_fn_result
 
-    lock = lk.Lock(lock_path)
-
-    def do_read_with_exception(exit_fn):
-        with lk.ReadTransaction(lock, TestContextManager, exit_fn):
-            raise Exception()
+    exit_fn_result, exit_ctx_result = False, False
+    vals = collections.defaultdict(lambda: False)
+    lock = MockLock(lock_path, vals)
 
-    def do_write_with_exception(exit_fn):
-        with lk.WriteTransaction(lock, TestContextManager, exit_fn):
-            raise Exception()
+    with transaction(lock, acquire=TestContextManager, release=exit_fn):
+        pass
 
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with pytest.raises(Exception):
-        do_read_with_exception(exit_fn)
-    assert vals['entered']
-    assert vals['exited']
-    assert vals['exception']
+    assert vals['entered_ctx']
+    assert vals['exited_ctx']
     assert vals['exited_fn']
-    assert vals['exception_fn']
-
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with pytest.raises(Exception):
-        do_read_with_exception(None)
-    assert vals['entered']
-    assert vals['exited']
-    assert vals['exception']
-    assert not vals['exited_fn']
+    assert not vals['exception_ctx']
     assert not vals['exception_fn']
 
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with pytest.raises(Exception):
-        do_write_with_exception(exit_fn)
-    assert vals['entered']
-    assert vals['exited']
-    assert vals['exception']
-    assert vals['exited_fn']
-    assert vals['exception_fn']
+    vals.clear()
+    with transaction(lock, acquire=TestContextManager):
+        pass
 
-    vals = {'entered': False, 'exited': False, 'exited_fn': False,
-            'exception': False, 'exception_fn': False}
-    with pytest.raises(Exception):
-        do_write_with_exception(None)
-    assert vals['entered']
-    assert vals['exited']
-    assert vals['exception']
+    assert vals['entered_ctx']
+    assert vals['exited_ctx']
     assert not vals['exited_fn']
+    assert not vals['exception_ctx']
     assert not vals['exception_fn']
 
+    # below are tests for exceptions with and without suppression
+    def assert_ctx_and_fn_exception(raises=True):
+        vals.clear()
+
+        if raises:
+            with pytest.raises(Exception):
+                with transaction(
+                        lock, acquire=TestContextManager, release=exit_fn):
+                    raise Exception()
+        else:
+            with transaction(
+                    lock, acquire=TestContextManager, release=exit_fn):
+                raise Exception()
+
+        assert vals['entered_ctx']
+        assert vals['exited_ctx']
+        assert vals['exited_fn']
+        assert vals['exception_ctx']
+        assert vals['exception_fn']
+
+    def assert_only_ctx_exception(raises=True):
+        vals.clear()
+
+        if raises:
+            with pytest.raises(Exception):
+                with transaction(lock, acquire=TestContextManager):
+                    raise Exception()
+        else:
+            with transaction(lock, acquire=TestContextManager):
+                raise Exception()
+
+        assert vals['entered_ctx']
+        assert vals['exited_ctx']
+        assert not vals['exited_fn']
+        assert vals['exception_ctx']
+        assert not vals['exception_fn']
+
+    # no suppression
+    assert_ctx_and_fn_exception(raises=True)
+    assert_only_ctx_exception(raises=True)
+
+    # suppress exception only in function
+    exit_fn_result, exit_ctx_result = True, False
+    assert_ctx_and_fn_exception(raises=False)
+    assert_only_ctx_exception(raises=True)
+
+    # suppress exception only in context
+    exit_fn_result, exit_ctx_result = False, True
+    assert_ctx_and_fn_exception(raises=False)
+    assert_only_ctx_exception(raises=False)
+
+    # suppress exception in function and context
+    exit_fn_result, exit_ctx_result = True, True
+    assert_ctx_and_fn_exception(raises=False)
+    assert_only_ctx_exception(raises=False)
+
 
 def test_lock_debug_output(lock_path):
     host = socket.getfqdn()
diff --git a/lib/spack/spack/util/file_cache.py b/lib/spack/spack/util/file_cache.py
index d56f2b33c5..0227edf155 100644
--- a/lib/spack/spack/util/file_cache.py
+++ b/lib/spack/spack/util/file_cache.py
@@ -107,7 +107,8 @@ def read_transaction(self, key):
 
         """
         return ReadTransaction(
-            self._get_lock(key), lambda: open(self.cache_path(key)))
+            self._get_lock(key), acquire=lambda: open(self.cache_path(key))
+        )
 
     def write_transaction(self, key):
         """Get a write transaction on a file cache item.
@@ -117,6 +118,10 @@ def write_transaction(self, key):
         moves the file into place on top of the old file atomically.
 
         """
+        # TODO: this nested context manager adds a lot of complexity and
+        # TODO: is pretty hard to reason about in llnl.util.lock. At some
+        # TODO: point we should just replace it with functions and simplify
+        # TODO: the locking code.
         class WriteContextManager(object):
 
             def __enter__(cm):  # noqa
@@ -142,7 +147,8 @@ def __exit__(cm, type, value, traceback):  # noqa
                 else:
                     os.rename(cm.tmp_filename, cm.orig_filename)
 
-        return WriteTransaction(self._get_lock(key), WriteContextManager)
+        return WriteTransaction(
+            self._get_lock(key), acquire=WriteContextManager)
 
     def mtime(self, key):
         """Return modification time of cache file, or 0 if it does not exist.
-- 
GitLab