[Neo-report] r2531 vincent - in /trunk/neo: client/ client/handlers/ tests/client/

nobody at svn.erp5.org nobody at svn.erp5.org
Tue Dec 14 16:56:50 CET 2010


Author: vincent
Date: Tue Dec 14 16:56:50 2010
New Revision: 2531

Log:
Implement revision-aware caching.

Modified:
    trunk/neo/client/app.py
    trunk/neo/client/handlers/master.py
    trunk/neo/tests/client/testClientApp.py
    trunk/neo/tests/client/testMasterHandler.py

Modified: trunk/neo/client/app.py
==============================================================================
--- trunk/neo/client/app.py [iso-8859-1] (original)
+++ trunk/neo/client/app.py [iso-8859-1] Tue Dec 14 16:56:50 2010
@@ -43,7 +43,7 @@ from neo.client.handlers import storage,
 from neo.dispatcher import Dispatcher, ForgottenPacket
 from neo.client.poll import ThreadedPoll, psThreadedPoll
 from neo.client.iterator import Iterator
-from neo.client.mq import MQ
+from neo.client.mq import MQ, MQIndex
 from neo.client.pool import ConnectionPool
 from neo.util import u64, parseMasterList
 from neo.profiling import profiler_decorator, PROFILING_ENABLED
@@ -119,6 +119,139 @@ class ThreadContext(object):
             'last_transaction': None,
         }
 
+class RevisionIndex(MQIndex):
+    """
+    This cache index allows accessing a specific revision of a cached object.
+    It requires cache key to be a 2-tuple, composed of oid and revision.
+
+    Note: it is expected that rather few revisions are held in cache, with few
+    lookups for old revisions, so they are held in a simple sorted list
+    Note2: all methods here must be called with cache lock acquired.
+    """
+    def __init__(self):
+        # key: oid
+        # value: tid list, from highest to lowest
+        self._oid_dict = {}
+        # key: oid
+        # value: tid list, from lowest to highest
+        self._invalidated = {}
+
+    def clear(self):
+        self._oid_dict.clear()
+        self._invalidated.clear()
+
+    def remove(self, key):
+        oid_dict = self._oid_dict
+        oid, tid = key
+        tid_list = oid_dict[oid]
+        tid_list.remove(tid)
+        if not tid_list:
+            # No more serial known for this object, drop entirely
+            del oid_dict[oid]
+            self._invalidated.pop(oid, None)
+
+    def add(self, key):
+        oid_dict = self._oid_dict
+        oid, tid = key
+        try:
+            serial_list = oid_dict[oid]
+        except KeyError:
+            serial_list = oid_dict[oid] = []
+        else:
+            assert tid not in serial_list
+        if not(serial_list) or tid > serial_list[0]:
+            serial_list.insert(0, tid)
+        else:
+            serial_list.insert(0, tid)
+            serial_list.sort(reverse=True)
+        invalidated = self._invalidated
+        try:
+            tid_list = invalidated[oid]
+        except KeyError:
+            pass
+        else:
+            try:
+                tid_list.remove(tid)
+            except ValueError:
+                pass
+            else:
+                if not tid_list:
+                    del invalidated[oid]
+
+    def invalidate(self, oid_list, tid):
+        """
+        Mark object invalidated by given transaction.
+        Must be called with increasing TID values (which is standard for
+        ZODB).
+        """
+        invalidated = self._invalidated
+        oid_dict = self._oid_dict
+        for oid in (x for x in oid_list if x in oid_dict):
+            try:
+                tid_list = invalidated[oid]
+            except KeyError:
+                tid_list = invalidated[oid] = []
+            assert not tid_list or tid > tid_list[-1], (dump(oid), dump(tid),
+                dump(tid_list[-1]))
+            tid_list.append(tid)
+
+    def getSerialBefore(self, oid, tid):
+        """
+        Get the first tid in cache whose value is lower than given tid.
+        """
+        # WARNING: return-intensive to save on indentation
+        oid_list = self._oid_dict.get(oid)
+        if oid_list is None:
+            # Unknown oid
+            return None
+        for result in oid_list:
+            if result < tid:
+                # Candidate found
+                break
+        else:
+            # No candidate in cache.
+            return None
+        # Check if there is a chance that an intermediate revision would
+        # exist, while missing from cache.
+        try:
+            inv_tid_list = self._invalidated[oid]
+        except KeyError:
+            return result
+        # Remember: inv_tid_list is sorted in ascending order.
+        for inv_tid in inv_tid_list:
+            if tid < inv_tid:
+                # We don't care about invalidations past requested TID.
+                break
+            elif result < inv_tid < tid:
+                # An invalidation was received between candidate revision,
+                # and before requested TID: there is a matching revision we
+                # don't know of, so we cannot answer.
+                return None
+        return result
+
+    def getLatestSerial(self, oid):
+        """
+        Get the latest tid for given object.
+        """
+        result = self._oid_dict.get(oid)
+        if result is not None:
+            result = result[0]
+            try:
+                tid_list = self._invalidated[oid]
+            except KeyError:
+                pass
+            else:
+                if result < tid_list[-1]:
+                    # An invalidation happened from a transaction later than our
+                    # most recent view of this object, so we cannot answer.
+                    result = None
+        return result
+
+    def getSerialList(self, oid):
+        """
+        Get the list of all serials cache knows about for given object.
+        """
+        return self._oid_dict.get(oid, [])[:]
 
 class Application(object):
     """The client node application."""
@@ -147,6 +280,8 @@ class Application(object):
         # no self-assigned UUID, primary master will supply us one
         self.uuid = None
         self.mq_cache = MQ()
+        self.cache_revision_index = RevisionIndex()
+        self.mq_cache.addIndex(self.cache_revision_index)
         self.new_oid_list = []
         self.last_oid = '\0' * 8
         self.storage_event_handler = storage.StorageEventHandler(self)
@@ -429,7 +564,7 @@ class Application(object):
         return int(u64(self.last_oid))
 
     @profiler_decorator
-    def _load(self, oid, serial=None, tid=None, cache=0):
+    def _load(self, oid, serial=None, tid=None):
         """
         Internal method which manage load, loadSerial and loadBefore.
         OID and TID (serial) parameters are expected packed.
@@ -441,8 +576,6 @@ class Application(object):
         tid
             If given, the excluded upper bound serial at which OID is desired.
             serial should be None.
-        cache
-            Store data in cache for future lookups.
 
         Return value: (3-tuple)
         - Object data (None if object creation was undone).
@@ -471,21 +604,19 @@ class Application(object):
             if not self.local_var.barrier_done:
                 self.invalidationBarrier()
                 self.local_var.barrier_done = True
-            if cache:
-                try:
-                    result = self._loadFromCache(oid, serial, tid)
-                except KeyError:
-                    pass
-                else:
-                    return result
+            try:
+                result = self._loadFromCache(oid, serial, tid)
+            except KeyError:
+                pass
+            else:
+                return result
             data, start_serial, end_serial = self._loadFromStorage(oid, serial,
                 tid)
-            if cache:
-                self._cache_lock_acquire()
-                try:
-                    self.mq_cache[oid] = start_serial, data
-                finally:
-                    self._cache_lock_release()
+            self._cache_lock_acquire()
+            try:
+                self.mq_cache[(oid, start_serial)] = data, end_serial
+            finally:
+                self._cache_lock_release()
             if data == '':
                 raise NEOStorageCreationUndoneError(dump(oid))
             return data, start_serial, end_serial
@@ -555,16 +686,25 @@ class Application(object):
         """
         self._cache_lock_acquire()
         try:
-            tid, data = self.mq_cache[oid]
-            neo.logging.debug('load oid %s is cached', dump(oid))
-            return (data, tid, None)
+            if at_tid is not None:
+                tid = at_tid
+            elif before_tid is not None:
+                tid = self.cache_revision_index.getSerialBefore(oid,
+                    before_tid)
+            else:
+                tid = self.cache_revision_index.getLatestSerial(oid)
+            if tid is None:
+                raise KeyError
+            # Raises KeyError on miss
+            data, next_tid = self.mq_cache[(oid, tid)]
+            return (data, tid, next_tid)
         finally:
             self._cache_lock_release()
 
     @profiler_decorator
     def load(self, oid, version=None):
         """Load an object for a given oid."""
-        result = self._load(oid, cache=1)[:2]
+        result = self._load(oid)[:2]
         # Start a network barrier, so we get all invalidations *after* we
         # received data. This ensures we get any invalidation message that
         # would have been about the version we loaded.
@@ -578,7 +718,6 @@ class Application(object):
     @profiler_decorator
     def loadSerial(self, oid, serial):
         """Load an object for a given oid and serial."""
-        # Do not try in cache as it manages only up-to-date object
         neo.logging.debug('loading %s at %s', dump(oid), dump(serial))
         return self._load(oid, serial=serial)[0]
 
@@ -586,7 +725,6 @@ class Application(object):
     @profiler_decorator
     def loadBefore(self, oid, tid):
         """Load an object for a given oid before tid committed."""
-        # Do not try in cache as it manages only up-to-date object
         neo.logging.debug('loading %s before %s', dump(oid), dump(tid))
         return self._load(oid, tid=tid)
 
@@ -878,12 +1016,30 @@ class Application(object):
             self._cache_lock_acquire()
             try:
                 mq_cache = self.mq_cache
+                update = mq_cache.update
+                def updateNextSerial(value):
+                    data, next_tid = value
+                    assert next_tid is None, (dump(oid), dump(base_tid),
+                        dump(next_tid))
+                    return (data, tid)
+                get_baseTID = local_var.object_serial_dict.get
                 for oid, data in local_var.data_dict.iteritems():
+                    if data is None:
+                        # this is just a remnant of a
+                        # checkCurrentSerialInTransaction call, ignore (no data
+                        # was modified).
+                        continue
+                    # Update ex-latest value in cache
+                    base_tid = get_baseTID(oid)
+                    try:
+                        update((oid, base_tid), updateNextSerial)
+                    except KeyError:
+                        pass
                     if data == '':
-                        if oid in mq_cache:
-                            del mq_cache[oid]
+                        self.cache_revision_index.invalidate([oid], tid)
                     else:
-                        mq_cache[oid] = tid, data
+                        # Store in cache with no next_tid
+                        mq_cache[(oid, tid)] = (data, None)
             finally:
                 self._cache_lock_release()
             local_var.clear()
@@ -1234,6 +1390,15 @@ class Application(object):
         if tid == ZERO_TID:
             raise NEOStorageError('Invalid pack time')
         self._askPrimary(Packets.AskPack(tid))
+        # XXX: this is only needed to make ZODB unit tests pass.
+        # It should not be otherwise required (clients should be free to load
+        # old data as long as it is available in cache, even if it was pruned
+        # by a pack), so don't bother invalidating on other clients.
+        self._cache_lock_acquire()
+        try:
+            self.mq_cache.clear()
+        finally:
+            self._cache_lock_release()
 
     def getLastTID(self, oid):
         return self._load(oid)[1]

Modified: trunk/neo/client/handlers/master.py
==============================================================================
--- trunk/neo/client/handlers/master.py [iso-8859-1] (original)
+++ trunk/neo/client/handlers/master.py [iso-8859-1] Tue Dec 14 16:56:50 2010
@@ -123,10 +123,7 @@ class PrimaryNotificationsHandler(BaseHa
         app._cache_lock_acquire()
         try:
             # ZODB required a dict with oid as key, so create it
-            mq_cache = app.mq_cache
-            for oid in oid_list:
-                if oid in mq_cache:
-                    del mq_cache[oid]
+            app.cache_revision_index.invalidate(oid_list, tid)
             db = app.getDB()
             if db is not None:
                 db.invalidate(tid, dict.fromkeys(oid_list, tid))

Modified: trunk/neo/tests/client/testClientApp.py
==============================================================================
--- trunk/neo/tests/client/testClientApp.py [iso-8859-1] (original)
+++ trunk/neo/tests/client/testClientApp.py [iso-8859-1] Tue Dec 14 16:56:50 2010
@@ -21,7 +21,7 @@ from cPickle import dumps
 from mock import Mock, ReturnValues
 from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
 from neo.tests import NeoUnitTestBase
-from neo.client.app import Application
+from neo.client.app import Application, RevisionIndex
 from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
 from neo.client.exception import NEOStorageDoesNotExistError
 from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL
@@ -208,7 +208,8 @@ class ClientApplicationTests(NeoUnitTest
         tid2 = self.makeTID(2)
         an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
         # connection to SN close
-        self.assertTrue(oid not in mq)
+        self.assertTrue((oid, tid1) not in mq)
+        self.assertTrue((oid, tid2) not in mq)
         packet = Errors.OidNotFound('')
         packet.setId(0)
         cell = Mock({ 'getUUID': '\x00' * 16})
@@ -224,7 +225,8 @@ class ClientApplicationTests(NeoUnitTest
         self.checkAskObject(conn)
         Application._waitMessage = _waitMessage
         # object not found in NEO -> NEOStorageNotFoundError
-        self.assertTrue(oid not in mq)
+        self.assertTrue((oid, tid1) not in mq)
+        self.assertTrue((oid, tid2) not in mq)
         packet = Errors.OidNotFound('')
         packet.setId(0)
         cell = Mock({ 'getUUID': '\x00' * 16})
@@ -254,7 +256,7 @@ class ClientApplicationTests(NeoUnitTest
         result = app.load(oid)
         self.assertEquals(result, ('OBJ', tid1))
         self.checkAskObject(conn)
-        self.assertTrue(oid in mq)
+        self.assertTrue((oid, tid1) in mq)
         # object is now cached, try to reload it
         conn = Mock({
             'getAddress': ('127.0.0.1', 0),
@@ -272,7 +274,8 @@ class ClientApplicationTests(NeoUnitTest
         tid1 = self.makeTID(1)
         tid2 = self.makeTID(2)
         # object not found in NEO -> NEOStorageNotFoundError
-        self.assertTrue(oid not in mq)
+        self.assertTrue((oid, tid1) not in mq)
+        self.assertTrue((oid, tid2) not in mq)
         packet = Errors.OidNotFound('')
         packet.setId(0)
         cell = Mock({ 'getUUID': '\x00' * 16})
@@ -285,10 +288,10 @@ class ClientApplicationTests(NeoUnitTest
         self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2)
         self.checkAskObject(conn)
         # object should not have been cached
-        self.assertFalse(oid in mq)
+        self.assertFalse((oid, tid2) in mq)
         # now a cached version ewxists but should not be hit
-        mq.store(oid, (tid2, 'WRONG'))
-        self.assertTrue(oid in mq)
+        mq.store((oid, tid2), ('WRONG', None))
+        self.assertTrue((oid, tid2) in mq)
         another_object = (1, oid, tid2, INVALID_SERIAL, 0,
             makeChecksum('RIGHT'), 'RIGHT', None)
         packet = Packets.AnswerObject(*another_object[1:])
@@ -302,7 +305,7 @@ class ClientApplicationTests(NeoUnitTest
         result = app.loadSerial(oid, tid1)
         self.assertEquals(result, 'RIGHT')
         self.checkAskObject(conn)
-        self.assertTrue(oid in mq)
+        self.assertTrue((oid, tid2) in mq)
 
     def test_loadBefore(self):
         app = self.getApp()
@@ -313,7 +316,8 @@ class ClientApplicationTests(NeoUnitTest
         tid2 = self.makeTID(2)
         tid3 = self.makeTID(3)
         # object not found in NEO -> NEOStorageDoesNotExistError
-        self.assertTrue(oid not in mq)
+        self.assertTrue((oid, tid1) not in mq)
+        self.assertTrue((oid, tid2) not in mq)
         packet = Errors.OidDoesNotExist('')
         packet.setId(0)
         cell = Mock({ 'getUUID': '\x00' * 16})
@@ -337,11 +341,12 @@ class ClientApplicationTests(NeoUnitTest
         app.local_var.asked_object = an_object[:-1]
         self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1)
         # object should not have been cached
-        self.assertFalse(oid in mq)
+        self.assertFalse((oid, tid1) in mq)
         # as for loadSerial, the object is cached but should be loaded from db
-        mq.store(oid, (tid1, 'WRONG'))
-        self.assertTrue(oid in mq)
-        another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'),
+        mq.store((oid, tid1), ('WRONG', tid2))
+        self.assertTrue((oid, tid1) in mq)
+        app.cache_revision_index.invalidate([oid], tid2)
+        another_object = (1, oid, tid2, tid3, 0, makeChecksum('RIGHT'),
             'RIGHT', None)
         packet = Packets.AnswerObject(*another_object[1:])
         packet.setId(0)
@@ -352,9 +357,9 @@ class ClientApplicationTests(NeoUnitTest
         app.cp = Mock({ 'getConnForCell' : conn})
         app.local_var.asked_object = another_object
         result = app.loadBefore(oid, tid3)
-        self.assertEquals(result, ('RIGHT', tid1, tid2))
+        self.assertEquals(result, ('RIGHT', tid2, tid3))
         self.checkAskObject(conn)
-        self.assertTrue(oid in mq)
+        self.assertTrue((oid, tid1) in mq)
 
     def test_tpc_begin(self):
         app = self.getApp()
@@ -1156,6 +1161,90 @@ class ClientApplicationTests(NeoUnitTest
         self.assertEqual(marker[0].getType(), Packets.AskPack)
         # XXX: how to validate packet content ?
 
+    def test_RevisionIndex_1(self):
+        # Test add, getLatestSerial, getSerialList and clear
+        # without invalidations
+        oid1 = self.getOID(1)
+        oid2 = self.getOID(2)
+        tid1 = self.getOID(1)
+        tid2 = self.getOID(2)
+        tid3 = self.getOID(3)
+        ri = RevisionIndex()
+        # index is empty
+        self.assertEqual(ri.getSerialList(oid1), [])
+        ri.add((oid1, tid1))
+        # now, it knows oid1 at tid1
+        self.assertEqual(ri.getLatestSerial(oid1), tid1)
+        self.assertEqual(ri.getSerialList(oid1), [tid1])
+        self.assertEqual(ri.getSerialList(oid2), [])
+        ri.add((oid1, tid2))
+        # and at tid2
+        self.assertEqual(ri.getLatestSerial(oid1), tid2)
+        self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
+        ri.remove((oid1, tid1))
+        # oid1 at tid1 was pruned from cache
+        self.assertEqual(ri.getLatestSerial(oid1), tid2)
+        self.assertEqual(ri.getSerialList(oid1), [tid2])
+        ri.remove((oid1, tid2))
+        # oid1 is completely pruned from cache
+        self.assertEqual(ri.getLatestSerial(oid1), None)
+        self.assertEqual(ri.getSerialList(oid1), [])
+        ri.add((oid1, tid2))
+        ri.add((oid1, tid1))
+        # oid1 is populated, but in non-chronological order, check index
+        # still answers consistent result.
+        self.assertEqual(ri.getLatestSerial(oid1), tid2)
+        self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
+        ri.add((oid2, tid3))
+        # which is not affected by the addition of oid2 at tid3
+        self.assertEqual(ri.getLatestSerial(oid1), tid2)
+        self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
+        ri.clear()
+        # index is empty again
+        self.assertEqual(ri.getSerialList(oid1), [])
+        self.assertEqual(ri.getSerialList(oid2), [])
+
+    def test_RevisionIndex_2(self):
+        # Test getLatestSerial & getSerialBefore with invalidations
+        oid1 = self.getOID(1)
+        tid1 = self.getOID(1)
+        tid2 = self.getOID(2)
+        tid3 = self.getOID(3)
+        tid4 = self.getOID(4)
+        tid5 = self.getOID(5)
+        tid6 = self.getOID(6)
+        ri = RevisionIndex()
+        ri.add((oid1, tid1))
+        ri.add((oid1, tid2))
+        self.assertEqual(ri.getLatestSerial(oid1), tid2)
+        self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
+        self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
+        self.assertEqual(ri.getSerialBefore(oid1, tid4), tid2)
+        ri.invalidate([oid1], tid3)
+        # We don't have the latest data in cache, return None
+        self.assertEqual(ri.getLatestSerial(oid1), None)
+        self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
+        self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
+        # There is a gap between the last version we have and requested one,
+        # return None
+        self.assertEqual(ri.getSerialBefore(oid1, tid4), None)
+        ri.add((oid1, tid3))
+        # No gap anymore, tid3 found.
+        self.assertEqual(ri.getLatestSerial(oid1), tid3)
+        self.assertEqual(ri.getSerialBefore(oid1, tid4), tid3)
+        ri.invalidate([oid1], tid4)
+        ri.invalidate([oid1], tid5)
+        # A bigger gap...
+        self.assertEqual(ri.getLatestSerial(oid1), None)
+        self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
+        self.assertEqual(ri.getSerialBefore(oid1, tid6), None)
+        # not entirely filled.
+        ri.add((oid1, tid5))
+        # Still, we know the latest and what is before tid6
+        self.assertEqual(ri.getLatestSerial(oid1), tid5)
+        self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
+        self.assertEqual(ri.getSerialBefore(oid1, tid6), tid5)
+
 if __name__ == '__main__':
     unittest.main()
 

Modified: trunk/neo/tests/client/testMasterHandler.py
==============================================================================
--- trunk/neo/tests/client/testMasterHandler.py [iso-8859-1] (original)
+++ trunk/neo/tests/client/testMasterHandler.py [iso-8859-1] Tue Dec 14 16:56:50 2010
@@ -156,17 +156,22 @@ class MasterNotificationsHandlerTests(Ma
     def test_invalidateObjects(self):
         conn = self.getConnection()
         tid = self.getNextTID()
-        oid1, oid2 = self.getOID(1), self.getOID(2)
+        oid1, oid2, oid3 = self.getOID(1), self.getOID(2), self.getOID(3)
         self.app.mq_cache = {
-            oid1: tid,
-            oid2: tid,
+            (oid1, tid): ('bla', None),
+            (oid2, tid): ('bla', None),
         }
-        self.handler.invalidateObjects(conn, tid, [oid1])
-        self.assertFalse(oid1 in self.app.mq_cache)
-        self.assertTrue(oid2 in self.app.mq_cache)
+        self.app.cache_revision_index = Mock({
+            'invalidate': None,
+        })
+        self.handler.invalidateObjects(conn, tid, [oid1, oid3])
+        cache_calls = self.app.cache_revision_index.mockGetNamedCalls(
+            'invalidate')
+        self.assertEqual(len(cache_calls), 1)
+        cache_calls[0].checkArgs([oid1, oid3], tid)
         invalidation_calls = self.db.mockGetNamedCalls('invalidate')
         self.assertEqual(len(invalidation_calls), 1)
-        invalidation_calls[0].checkArgs(tid, {oid1:tid})
+        invalidation_calls[0].checkArgs(tid, {oid1:tid, oid3:tid})
 
     def test_notifyPartitionChanges(self):
         conn = self.getConnection()




More information about the Neo-report mailing list