From b3a5f2e3c3ce4e3e9836301504f5b5987117248e Mon Sep 17 00:00:00 2001
From: Todd Gamblin <tgamblin@llnl.gov>
Date: Sat, 21 Dec 2019 16:29:53 -0800
Subject: [PATCH] lock transactions: ensure that nested write transactions
 write

If a write transaction was nested inside a read transaction, it would not
write properly on release, e.g., in a sequence like this, inside our
`LockTransaction` class:

```
1  with spack.store.db.read_transaction():
2    with spack.store.db.write_transaction():
3      ...
4  with spack.store.db.read_transaction():
   ...
```

The WriteTransaction on line 2 had no way of knowing that its
`__exit__()` call was the last *write* in the nesting, and it would skip
calling its write function.

The `__exit__()` call of the `ReadTransaction` on line 1 wouldn't know
how to write, and the file would never be written.

The DB would be correct in memory, but the `ReadTransaction` on line 4
would re-read the whole DB assuming that other processes may have
modified it.  Since the DB was never written, we got stale data.

- [x] Make `Lock.release_write()` return `True` whenever we release the
      *last write* in a nest.
---
 lib/spack/llnl/util/lock.py            |  8 +++-
 lib/spack/spack/test/llnl/util/lock.py | 57 ++++++++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/lib/spack/llnl/util/lock.py b/lib/spack/llnl/util/lock.py
index c675c7c452..86a45e2d7c 100644
--- a/lib/spack/llnl/util/lock.py
+++ b/lib/spack/llnl/util/lock.py
@@ -351,7 +351,13 @@ def release_write(self, release_fn=None):
 
         else:
             self._writes -= 1
-            return False
+
+            # when the last *write* is released, we call release_fn here
+            # instead of immediately before releasing the lock.
+            if self._writes == 0:
+                return release_fn() if release_fn is not None else True
+            else:
+                return False
 
     def _debug(self, *args):
         tty.debug(*args)
diff --git a/lib/spack/spack/test/llnl/util/lock.py b/lib/spack/spack/test/llnl/util/lock.py
index 3bf8a236b1..2b0892a25e 100644
--- a/lib/spack/spack/test/llnl/util/lock.py
+++ b/lib/spack/spack/test/llnl/util/lock.py
@@ -783,6 +783,12 @@ def __init__(self, lock_path, vals):
         super(AssertLock, self).__init__(lock_path)
         self.vals = vals
 
+    # assert hooks for subclasses
+    assert_acquire_read = lambda self: None
+    assert_acquire_write = lambda self: None
+    assert_release_read = lambda self: None
+    assert_release_write = lambda self: None
+
     def acquire_read(self, timeout=None):
         self.assert_acquire_read()
         result = super(AssertLock, self).acquire_read(timeout)
@@ -1030,6 +1036,57 @@ def assert_only_ctx_exception(raises=True):
     assert_only_ctx_exception(raises=False)
 
 
+def test_nested_write_transaction(lock_path):
+    """Ensure that the outermost write transaction writes."""
+
+    def write(t, v, tb):
+        vals['wrote'] = True
+
+    vals = collections.defaultdict(lambda: False)
+    lock = AssertLock(lock_path, vals)
+
+    # write/write
+    with lk.WriteTransaction(lock, release=write):
+        assert not vals['wrote']
+        with lk.WriteTransaction(lock, release=write):
+            assert not vals['wrote']
+        assert not vals['wrote']
+    assert vals['wrote']
+
+    # read/write
+    vals.clear()
+    with lk.ReadTransaction(lock):
+        assert not vals['wrote']
+        with lk.WriteTransaction(lock, release=write):
+            assert not vals['wrote']
+        assert vals['wrote']
+
+    # write/read/write
+    vals.clear()
+    with lk.WriteTransaction(lock, release=write):
+        assert not vals['wrote']
+        with lk.ReadTransaction(lock):
+            assert not vals['wrote']
+            with lk.WriteTransaction(lock, release=write):
+                assert not vals['wrote']
+            assert not vals['wrote']
+        assert not vals['wrote']
+    assert vals['wrote']
+
+    # read/write/read/write
+    vals.clear()
+    with lk.ReadTransaction(lock):
+        with lk.WriteTransaction(lock, release=write):
+            assert not vals['wrote']
+            with lk.ReadTransaction(lock):
+                assert not vals['wrote']
+                with lk.WriteTransaction(lock, release=write):
+                    assert not vals['wrote']
+                assert not vals['wrote']
+            assert not vals['wrote']
+        assert vals['wrote']
+
+
 def test_lock_debug_output(lock_path):
     host = socket.getfqdn()
 
-- 
GitLab