Skip to content
Snippets Groups Projects
Commit eb8fc4f3 authored by Todd Gamblin
Browse files

lock transactions: fix non-transactional writes

Lock transactions were actually writing *after* the lock was
released. The code was looking at the result of `release_write()` before
writing, then writing based on whether the lock was released.  This is
pretty obviously wrong.

- [x] Refactor `Lock` so that a release function can be passed to the
      `Lock` and called *only* when a lock is really released.

- [x] Refactor `LockTransaction` classes to use the release function
  instead of checking the return value of `release_read()` / `release_write()`
parent 779ac9fe
No related branches found
No related tags found
No related merge requests found
...@@ -95,10 +95,6 @@ def _lock(self, op, timeout=None): ...@@ -95,10 +95,6 @@ def _lock(self, op, timeout=None):
The lock is implemented as a spin lock using a nonblocking call The lock is implemented as a spin lock using a nonblocking call
to ``lockf()``. to ``lockf()``.
On acquiring an exclusive lock, the lock writes this process's
pid and host to the lock file, in case the holding process needs
to be killed later.
If the lock times out, it raises a ``LockError``. If the lock is If the lock times out, it raises a ``LockError``. If the lock is
successfully acquired, the total wait time and the number of attempts successfully acquired, the total wait time and the number of attempts
is returned. is returned.
...@@ -284,11 +280,19 @@ def acquire_write(self, timeout=None): ...@@ -284,11 +280,19 @@ def acquire_write(self, timeout=None):
self._writes += 1 self._writes += 1
return False return False
def release_read(self): def release_read(self, release_fn=None):
"""Releases a read lock. """Releases a read lock.
Returns True if the last recursive lock was released, False if Arguments:
there are still outstanding locks. release_fn (callable): function to call *before* the last recursive
lock (read or write) is released.
If the last recursive lock will be released, then this will call
release_fn and return its result (if provided), or return True
(if release_fn was not provided).
Otherwise, we are still nested inside some other lock, so do not
call the release_fn, and return False.
Does limited correctness checking: if a read lock is released Does limited correctness checking: if a read lock is released
when none are held, this will raise an assertion error. when none are held, this will raise an assertion error.
...@@ -300,18 +304,30 @@ def release_read(self): ...@@ -300,18 +304,30 @@ def release_read(self):
self._debug( self._debug(
'READ LOCK: {0.path}[{0._start}:{0._length}] [Released]' 'READ LOCK: {0.path}[{0._start}:{0._length}] [Released]'
.format(self)) .format(self))
result = True
if release_fn is not None:
result = release_fn()
self._unlock() # can raise LockError. self._unlock() # can raise LockError.
self._reads -= 1 self._reads -= 1
return True return result
else: else:
self._reads -= 1 self._reads -= 1
return False return False
def release_write(self): def release_write(self, release_fn=None):
"""Releases a write lock. """Releases a write lock.
Returns True if the last recursive lock was released, False if Arguments:
there are still outstanding locks. release_fn (callable): function to call before the last recursive
write is released.
If the last recursive *write* lock will be released, then this
will call release_fn and return its result (if provided), or
return True (if release_fn was not provided). Otherwise, we are
still nested inside some other write lock, so do not call the
release_fn, and return False.
Does limited correctness checking: if a read lock is released Does limited correctness checking: if a read lock is released
when none are held, this will raise an assertion error. when none are held, this will raise an assertion error.
...@@ -323,9 +339,16 @@ def release_write(self): ...@@ -323,9 +339,16 @@ def release_write(self):
self._debug( self._debug(
'WRITE LOCK: {0.path}[{0._start}:{0._length}] [Released]' 'WRITE LOCK: {0.path}[{0._start}:{0._length}] [Released]'
.format(self)) .format(self))
# we need to call release_fn before releasing the lock
result = True
if release_fn is not None:
result = release_fn()
self._unlock() # can raise LockError. self._unlock() # can raise LockError.
self._writes -= 1 self._writes -= 1
return True return result
else: else:
self._writes -= 1 self._writes -= 1
return False return False
...@@ -349,28 +372,36 @@ def _acquired_debug(self, lock_type, wait_time, nattempts): ...@@ -349,28 +372,36 @@ def _acquired_debug(self, lock_type, wait_time, nattempts):
class LockTransaction(object): class LockTransaction(object):
"""Simple nested transaction context manager that uses a file lock. """Simple nested transaction context manager that uses a file lock.
This class can trigger actions when the lock is acquired for the Arguments:
first time and released for the last. lock (Lock): underlying lock for this transaction to be acquired on
enter and released on exit
acquire (callable or contextmanager): function to be called after lock
is acquired, or contextmanager to enter after acquire and leave
before release.
release (callable): function to be called before release. If
``acquire`` is a contextmanager, this will be called *after*
exiting the nested context and before the lock is released.
timeout (float): number of seconds to set for the timeout when
acquiring the lock (default no timeout)
If the ``acquire_fn`` returns a value, it is used as the return value for If the ``acquire_fn`` returns a value, it is used as the return value for
``__enter__``, allowing it to be passed as the ``as`` argument of a ``__enter__``, allowing it to be passed as the ``as`` argument of a
``with`` statement. ``with`` statement.
If ``acquire_fn`` returns a context manager, *its* ``__enter__`` function If ``acquire_fn`` returns a context manager, *its* ``__enter__`` function
will be called in ``__enter__`` after ``acquire_fn``, and its ``__exit__`` will be called after the lock is acquired, and its ``__exit__`` function
function will be called before ``release_fn`` in ``__exit__``, allowing you will be called before ``release_fn`` in ``__exit__``, allowing you to
to nest a context manager to be used along with the lock. nest a context manager inside this one.
Timeout for lock is customizable. Timeout for lock is customizable.
""" """
def __init__(self, lock, acquire_fn=None, release_fn=None, def __init__(self, lock, acquire=None, release=None, timeout=None):
timeout=None):
self._lock = lock self._lock = lock
self._timeout = timeout self._timeout = timeout
self._acquire_fn = acquire_fn self._acquire_fn = acquire
self._release_fn = release_fn self._release_fn = release
self._as = None self._as = None
def __enter__(self): def __enter__(self):
...@@ -383,13 +414,18 @@ def __enter__(self): ...@@ -383,13 +414,18 @@ def __enter__(self):
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
suppress = False suppress = False
if self._exit():
if self._as and hasattr(self._as, '__exit__'): def release_fn():
if self._as.__exit__(type, value, traceback): if self._release_fn is not None:
suppress = True return self._release_fn(type, value, traceback)
if self._release_fn:
if self._release_fn(type, value, traceback): if self._as and hasattr(self._as, '__exit__'):
suppress = True if self._as.__exit__(type, value, traceback):
suppress = True
if self._exit(release_fn):
suppress = True
return suppress return suppress
...@@ -398,8 +434,8 @@ class ReadTransaction(LockTransaction): ...@@ -398,8 +434,8 @@ class ReadTransaction(LockTransaction):
def _enter(self): def _enter(self):
return self._lock.acquire_read(self._timeout) return self._lock.acquire_read(self._timeout)
def _exit(self): def _exit(self, release_fn):
return self._lock.release_read() return self._lock.release_read(release_fn)
class WriteTransaction(LockTransaction): class WriteTransaction(LockTransaction):
...@@ -407,8 +443,8 @@ class WriteTransaction(LockTransaction): ...@@ -407,8 +443,8 @@ class WriteTransaction(LockTransaction):
def _enter(self): def _enter(self):
return self._lock.acquire_write(self._timeout) return self._lock.acquire_write(self._timeout)
def _exit(self): def _exit(self, release_fn):
return self._lock.release_write() return self._lock.release_write(release_fn)
class LockError(Exception): class LockError(Exception):
......
...@@ -332,11 +332,12 @@ def __init__(self, root, db_dir=None, upstream_dbs=None, ...@@ -332,11 +332,12 @@ def __init__(self, root, db_dir=None, upstream_dbs=None,
def write_transaction(self): def write_transaction(self):
"""Get a write lock context manager for use in a `with` block.""" """Get a write lock context manager for use in a `with` block."""
return WriteTransaction(self.lock, self._read, self._write) return WriteTransaction(
self.lock, acquire=self._read, release=self._write)
def read_transaction(self): def read_transaction(self):
"""Get a read lock context manager for use in a `with` block.""" """Get a read lock context manager for use in a `with` block."""
return ReadTransaction(self.lock, self._read) return ReadTransaction(self.lock, acquire=self._read)
def prefix_lock(self, spec): def prefix_lock(self, spec):
"""Get a lock on a particular spec's installation directory. """Get a lock on a particular spec's installation directory.
...@@ -624,7 +625,7 @@ def _read_suppress_error(): ...@@ -624,7 +625,7 @@ def _read_suppress_error():
self._data = {} self._data = {}
transaction = WriteTransaction( transaction = WriteTransaction(
self.lock, _read_suppress_error, self._write self.lock, acquire=_read_suppress_error, release=self._write
) )
with transaction: with transaction:
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
actually on a shared filesystem. actually on a shared filesystem.
""" """
import collections
import os import os
import socket import socket
import shutil import shutil
...@@ -776,189 +777,258 @@ def p3(barrier): ...@@ -776,189 +777,258 @@ def p3(barrier):
multiproc_test(p1, p2, p3) multiproc_test(p1, p2, p3)
def test_transaction(lock_path): class AssertLock(lk.Lock):
"""Test lock class that marks acquire/release events."""
def __init__(self, lock_path, vals):
super(AssertLock, self).__init__(lock_path)
self.vals = vals
def acquire_read(self, timeout=None):
self.assert_acquire_read()
result = super(AssertLock, self).acquire_read(timeout)
self.vals['acquired_read'] = True
return result
def acquire_write(self, timeout=None):
self.assert_acquire_write()
result = super(AssertLock, self).acquire_write(timeout)
self.vals['acquired_write'] = True
return result
def release_read(self, release_fn=None):
self.assert_release_read()
result = super(AssertLock, self).release_read(release_fn)
self.vals['released_read'] = True
return result
def release_write(self, release_fn=None):
self.assert_release_write()
result = super(AssertLock, self).release_write(release_fn)
self.vals['released_write'] = True
return result
@pytest.mark.parametrize(
"transaction,type",
[(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
)
def test_transaction(lock_path, transaction, type):
class MockLock(AssertLock):
def assert_acquire_read(self):
assert not vals['entered_fn']
assert not vals['exited_fn']
def assert_release_read(self):
assert vals['entered_fn']
assert not vals['exited_fn']
def assert_acquire_write(self):
assert not vals['entered_fn']
assert not vals['exited_fn']
def assert_release_write(self):
assert vals['entered_fn']
assert not vals['exited_fn']
def enter_fn(): def enter_fn():
vals['entered'] = True # assert enter_fn is called while lock is held
assert vals['acquired_%s' % type]
vals['entered_fn'] = True
def exit_fn(t, v, tb): def exit_fn(t, v, tb):
vals['exited'] = True # assert exit_fn is called while lock is held
assert not vals['released_%s' % type]
vals['exited_fn'] = True
vals['exception'] = (t or v or tb) vals['exception'] = (t or v or tb)
lock = lk.Lock(lock_path) vals = collections.defaultdict(lambda: False)
vals = {'entered': False, 'exited': False, 'exception': False} lock = MockLock(lock_path, vals)
with lk.ReadTransaction(lock, enter_fn, exit_fn):
pass with transaction(lock, acquire=enter_fn, release=exit_fn):
assert vals['acquired_%s' % type]
assert not vals['released_%s' % type]
assert vals['entered'] assert vals['entered_fn']
assert vals['exited'] assert vals['exited_fn']
assert vals['acquired_%s' % type]
assert vals['released_%s' % type]
assert not vals['exception'] assert not vals['exception']
vals = {'entered': False, 'exited': False, 'exception': False}
with lk.WriteTransaction(lock, enter_fn, exit_fn):
pass
assert vals['entered'] @pytest.mark.parametrize(
assert vals['exited'] "transaction,type",
assert not vals['exception'] [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
)
def test_transaction_with_exception(lock_path, transaction, type):
class MockLock(AssertLock):
def assert_acquire_read(self):
assert not vals['entered_fn']
assert not vals['exited_fn']
def assert_release_read(self):
assert vals['entered_fn']
assert not vals['exited_fn']
def assert_acquire_write(self):
assert not vals['entered_fn']
assert not vals['exited_fn']
def assert_release_write(self):
assert vals['entered_fn']
assert not vals['exited_fn']
def test_transaction_with_exception(lock_path):
def enter_fn(): def enter_fn():
vals['entered'] = True assert vals['acquired_%s' % type]
vals['entered_fn'] = True
def exit_fn(t, v, tb): def exit_fn(t, v, tb):
vals['exited'] = True assert not vals['released_%s' % type]
vals['exited_fn'] = True
vals['exception'] = (t or v or tb) vals['exception'] = (t or v or tb)
return exit_result
lock = lk.Lock(lock_path) exit_result = False
vals = collections.defaultdict(lambda: False)
def do_read_with_exception(): lock = MockLock(lock_path, vals)
with lk.ReadTransaction(lock, enter_fn, exit_fn):
raise Exception()
def do_write_with_exception():
with lk.WriteTransaction(lock, enter_fn, exit_fn):
raise Exception()
vals = {'entered': False, 'exited': False, 'exception': False}
with pytest.raises(Exception): with pytest.raises(Exception):
do_read_with_exception() with transaction(lock, acquire=enter_fn, release=exit_fn):
assert vals['entered'] raise Exception()
assert vals['exited']
assert vals['exception']
vals = {'entered': False, 'exited': False, 'exception': False} assert vals['entered_fn']
with pytest.raises(Exception): assert vals['exited_fn']
do_write_with_exception()
assert vals['entered']
assert vals['exited']
assert vals['exception'] assert vals['exception']
# test suppression of exceptions from exit_fn
exit_result = True
vals.clear()
def test_transaction_with_context_manager(lock_path): # should not raise now.
class TestContextManager(object): with transaction(lock, acquire=enter_fn, release=exit_fn):
raise Exception()
def __enter__(self):
vals['entered'] = True
def __exit__(self, t, v, tb):
vals['exited'] = True
vals['exception'] = (t or v or tb)
def exit_fn(t, v, tb):
vals['exited_fn'] = True
vals['exception_fn'] = (t or v or tb)
lock = lk.Lock(lock_path)
vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False}
with lk.ReadTransaction(lock, TestContextManager, exit_fn):
pass
assert vals['entered'] assert vals['entered_fn']
assert vals['exited']
assert not vals['exception']
assert vals['exited_fn'] assert vals['exited_fn']
assert not vals['exception_fn'] assert vals['exception']
vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False}
with lk.ReadTransaction(lock, TestContextManager):
pass
assert vals['entered']
assert vals['exited']
assert not vals['exception']
assert not vals['exited_fn']
assert not vals['exception_fn']
vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False}
with lk.WriteTransaction(lock, TestContextManager, exit_fn):
pass
assert vals['entered'] @pytest.mark.parametrize(
assert vals['exited'] "transaction,type",
assert not vals['exception'] [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")]
assert vals['exited_fn'] )
assert not vals['exception_fn'] def test_transaction_with_context_manager(lock_path, transaction, type):
class MockLock(AssertLock):
def assert_acquire_read(self):
assert not vals['entered_ctx']
assert not vals['exited_ctx']
vals = {'entered': False, 'exited': False, 'exited_fn': False, def assert_release_read(self):
'exception': False, 'exception_fn': False} assert vals['entered_ctx']
with lk.WriteTransaction(lock, TestContextManager): assert vals['exited_ctx']
pass
assert vals['entered'] def assert_acquire_write(self):
assert vals['exited'] assert not vals['entered_ctx']
assert not vals['exception'] assert not vals['exited_ctx']
assert not vals['exited_fn']
assert not vals['exception_fn']
def assert_release_write(self):
assert vals['entered_ctx']
assert vals['exited_ctx']
def test_transaction_with_context_manager_and_exception(lock_path):
class TestContextManager(object): class TestContextManager(object):
def __enter__(self): def __enter__(self):
vals['entered'] = True vals['entered_ctx'] = True
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
vals['exited'] = True assert not vals['released_%s' % type]
vals['exception'] = (t or v or tb) vals['exited_ctx'] = True
vals['exception_ctx'] = (t or v or tb)
return exit_ctx_result
def exit_fn(t, v, tb): def exit_fn(t, v, tb):
assert not vals['released_%s' % type]
vals['exited_fn'] = True vals['exited_fn'] = True
vals['exception_fn'] = (t or v or tb) vals['exception_fn'] = (t or v or tb)
return exit_fn_result
lock = lk.Lock(lock_path) exit_fn_result, exit_ctx_result = False, False
vals = collections.defaultdict(lambda: False)
def do_read_with_exception(exit_fn): lock = MockLock(lock_path, vals)
with lk.ReadTransaction(lock, TestContextManager, exit_fn):
raise Exception()
def do_write_with_exception(exit_fn): with transaction(lock, acquire=TestContextManager, release=exit_fn):
with lk.WriteTransaction(lock, TestContextManager, exit_fn): pass
raise Exception()
vals = {'entered': False, 'exited': False, 'exited_fn': False, assert vals['entered_ctx']
'exception': False, 'exception_fn': False} assert vals['exited_ctx']
with pytest.raises(Exception):
do_read_with_exception(exit_fn)
assert vals['entered']
assert vals['exited']
assert vals['exception']
assert vals['exited_fn'] assert vals['exited_fn']
assert vals['exception_fn'] assert not vals['exception_ctx']
vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False}
with pytest.raises(Exception):
do_read_with_exception(None)
assert vals['entered']
assert vals['exited']
assert vals['exception']
assert not vals['exited_fn']
assert not vals['exception_fn'] assert not vals['exception_fn']
vals = {'entered': False, 'exited': False, 'exited_fn': False, vals.clear()
'exception': False, 'exception_fn': False} with transaction(lock, acquire=TestContextManager):
with pytest.raises(Exception): pass
do_write_with_exception(exit_fn)
assert vals['entered']
assert vals['exited']
assert vals['exception']
assert vals['exited_fn']
assert vals['exception_fn']
vals = {'entered': False, 'exited': False, 'exited_fn': False, assert vals['entered_ctx']
'exception': False, 'exception_fn': False} assert vals['exited_ctx']
with pytest.raises(Exception):
do_write_with_exception(None)
assert vals['entered']
assert vals['exited']
assert vals['exception']
assert not vals['exited_fn'] assert not vals['exited_fn']
assert not vals['exception_ctx']
assert not vals['exception_fn'] assert not vals['exception_fn']
# below are tests for exceptions with and without suppression
def assert_ctx_and_fn_exception(raises=True):
vals.clear()
if raises:
with pytest.raises(Exception):
with transaction(
lock, acquire=TestContextManager, release=exit_fn):
raise Exception()
else:
with transaction(
lock, acquire=TestContextManager, release=exit_fn):
raise Exception()
assert vals['entered_ctx']
assert vals['exited_ctx']
assert vals['exited_fn']
assert vals['exception_ctx']
assert vals['exception_fn']
def assert_only_ctx_exception(raises=True):
vals.clear()
if raises:
with pytest.raises(Exception):
with transaction(lock, acquire=TestContextManager):
raise Exception()
else:
with transaction(lock, acquire=TestContextManager):
raise Exception()
assert vals['entered_ctx']
assert vals['exited_ctx']
assert not vals['exited_fn']
assert vals['exception_ctx']
assert not vals['exception_fn']
# no suppression
assert_ctx_and_fn_exception(raises=True)
assert_only_ctx_exception(raises=True)
# suppress exception only in function
exit_fn_result, exit_ctx_result = True, False
assert_ctx_and_fn_exception(raises=False)
assert_only_ctx_exception(raises=True)
# suppress exception only in context
exit_fn_result, exit_ctx_result = False, True
assert_ctx_and_fn_exception(raises=False)
assert_only_ctx_exception(raises=False)
# suppress exception in function and context
exit_fn_result, exit_ctx_result = True, True
assert_ctx_and_fn_exception(raises=False)
assert_only_ctx_exception(raises=False)
def test_lock_debug_output(lock_path): def test_lock_debug_output(lock_path):
host = socket.getfqdn() host = socket.getfqdn()
......
...@@ -107,7 +107,8 @@ def read_transaction(self, key): ...@@ -107,7 +107,8 @@ def read_transaction(self, key):
""" """
return ReadTransaction( return ReadTransaction(
self._get_lock(key), lambda: open(self.cache_path(key))) self._get_lock(key), acquire=lambda: open(self.cache_path(key))
)
def write_transaction(self, key): def write_transaction(self, key):
"""Get a write transaction on a file cache item. """Get a write transaction on a file cache item.
...@@ -117,6 +118,10 @@ def write_transaction(self, key): ...@@ -117,6 +118,10 @@ def write_transaction(self, key):
moves the file into place on top of the old file atomically. moves the file into place on top of the old file atomically.
""" """
# TODO: this nested context manager adds a lot of complexity and
# TODO: is pretty hard to reason about in llnl.util.lock. At some
# TODO: point we should just replace it with functions and simplify
# TODO: the locking code.
class WriteContextManager(object): class WriteContextManager(object):
def __enter__(cm): # noqa def __enter__(cm): # noqa
...@@ -142,7 +147,8 @@ def __exit__(cm, type, value, traceback): # noqa ...@@ -142,7 +147,8 @@ def __exit__(cm, type, value, traceback): # noqa
else: else:
os.rename(cm.tmp_filename, cm.orig_filename) os.rename(cm.tmp_filename, cm.orig_filename)
return WriteTransaction(self._get_lock(key), WriteContextManager) return WriteTransaction(
self._get_lock(key), acquire=WriteContextManager)
def mtime(self, key): def mtime(self, key):
"""Return modification time of cache file, or 0 if it does not exist. """Return modification time of cache file, or 0 if it does not exist.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment