From 6c9467e8c6d5fffa6d9412d5d36ff2c91e33ea21 Mon Sep 17 00:00:00 2001
From: Todd Gamblin <tgamblin@llnl.gov>
Date: Sat, 21 Dec 2019 16:31:28 -0800
Subject: [PATCH] lock transactions: avoid redundant reading in write
 transactions

Our `LockTransaction` class was reading overly aggressively.  In cases
like this:

```
1  with spack.store.db.read_transaction():
2    with spack.store.db.write_transaction():
3      ...
```

The `ReadTransaction` on line 1 would read in the DB, but the
WriteTransaction on line 2 would read in the DB *again*, even though we
had a read lock the whole time.  `WriteTransaction`s were only
considering nested writes to decide when to read, but they didn't know
when we already had a read lock.

- [x] `Lock.acquire_write()` return `False` in cases where we already had
       a read lock.
---
 lib/spack/llnl/util/lock.py            |  7 +++-
 lib/spack/spack/test/llnl/util/lock.py | 56 ++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/lib/spack/llnl/util/lock.py b/lib/spack/llnl/util/lock.py
index 86a45e2d7c..3a58093491 100644
--- a/lib/spack/llnl/util/lock.py
+++ b/lib/spack/llnl/util/lock.py
@@ -275,7 +275,12 @@ def acquire_write(self, timeout=None):
             wait_time, nattempts = self._lock(fcntl.LOCK_EX, timeout=timeout)
             self._acquired_debug('WRITE LOCK', wait_time, nattempts)
             self._writes += 1
-            return True
+
+            # return True only if we weren't nested in a read lock.
+            # TODO: we may need to return two values: whether we got
+            # the write lock, and whether this is acquiring a read OR
+            # write lock for the first time. Now it returns the latter.
+            return self._reads == 0
         else:
             self._writes += 1
             return False
diff --git a/lib/spack/spack/test/llnl/util/lock.py b/lib/spack/spack/test/llnl/util/lock.py
index 2b0892a25e..ca879cdc0b 100644
--- a/lib/spack/spack/test/llnl/util/lock.py
+++ b/lib/spack/spack/test/llnl/util/lock.py
@@ -1087,6 +1087,62 @@ def write(t, v, tb):
         assert vals['wrote']
 
 
+def test_nested_reads(lock_path):
+    """Ensure that write transactions won't re-read data."""
+
+    def read():
+        vals['read'] += 1
+
+    vals = collections.defaultdict(lambda: 0)
+    lock = AssertLock(lock_path, vals)
+
+    # read/read
+    vals.clear()
+    assert vals['read'] == 0
+    with lk.ReadTransaction(lock, acquire=read):
+        assert vals['read'] == 1
+        with lk.ReadTransaction(lock, acquire=read):
+            assert vals['read'] == 1
+
+    # write/write
+    vals.clear()
+    assert vals['read'] == 0
+    with lk.WriteTransaction(lock, acquire=read):
+        assert vals['read'] == 1
+        with lk.WriteTransaction(lock, acquire=read):
+            assert vals['read'] == 1
+
+    # read/write
+    vals.clear()
+    assert vals['read'] == 0
+    with lk.ReadTransaction(lock, acquire=read):
+        assert vals['read'] == 1
+        with lk.WriteTransaction(lock, acquire=read):
+            assert vals['read'] == 1
+
+    # write/read/write
+    vals.clear()
+    assert vals['read'] == 0
+    with lk.WriteTransaction(lock, acquire=read):
+        assert vals['read'] == 1
+        with lk.ReadTransaction(lock, acquire=read):
+            assert vals['read'] == 1
+            with lk.WriteTransaction(lock, acquire=read):
+                assert vals['read'] == 1
+
+    # read/write/read/write
+    vals.clear()
+    assert vals['read'] == 0
+    with lk.ReadTransaction(lock, acquire=read):
+        assert vals['read'] == 1
+        with lk.WriteTransaction(lock, acquire=read):
+            assert vals['read'] == 1
+            with lk.ReadTransaction(lock, acquire=read):
+                assert vals['read'] == 1
+                with lk.WriteTransaction(lock, acquire=read):
+                    assert vals['read'] == 1
+
+
 def test_lock_debug_output(lock_path):
     host = socket.getfqdn()
 
-- 
GitLab