< prev index next >

src/java.base/share/classes/sun/security/util/Cache.java

Print this page

        

@@ -25,10 +25,14 @@
 
 package sun.security.util;
 
 import java.util.*;
 import java.lang.ref.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Predicate;
 
 /**
  * Abstract base class and factory for caches. A cache is a key-value mapping.
  * It has properties that make it more suitable for caching than a Map.
  *

@@ -238,311 +242,428 @@
 
 }
 
 class MemoryCache<K,V> extends Cache<K,V> {
 
-    private static final float LOAD_FACTOR = 0.75f;
-
     // XXXX
     private static final boolean DEBUG = false;
 
-    private final Map<K, CacheEntry<K,V>> cacheMap;
-    private int maxSize;
-    private long lifetime;
+    private final ConcurrentMap<K, CacheEntry<K,V>> cacheMap;
+    private CacheEntry<K, V> lruEntry, mruEntry;
+    private volatile int maxSize;
+    private volatile long lifetime;
 
     // ReferenceQueue is of type V instead of Cache<K,V>
     // to allow SoftCacheEntry to extend SoftReference<V>
     private final ReferenceQueue<V> queue;
 
     public MemoryCache(boolean soft, int maxSize) {
-        this(soft, maxSize, 0);
+        this(soft, maxSize, 0 /* no time out */);
     }
 
-    public MemoryCache(boolean soft, int maxSize, int lifetime) {
+    public MemoryCache(boolean soft, int maxSize, int timeout) {
         this.maxSize = maxSize;
-        this.lifetime = lifetime * 1000;
+        this.lifetime = timeout < 0 ? 0 : TimeUnit.SECONDS.toNanos(timeout);
         if (soft)
             this.queue = new ReferenceQueue<>();
         else
             this.queue = null;
 
-        int buckets = (int)(maxSize / LOAD_FACTOR) + 1;
-        cacheMap = new LinkedHashMap<>(buckets, LOAD_FACTOR, true);
+        cacheMap = new ConcurrentHashMap<>(maxSize);
     }
 
     /**
-     * Empty the reference queue and remove all corresponding entries
+     * Drain the reference queue and remove all corresponding entries
      * from the cache.
      *
      * This method should be called at the beginning of each public
      * method.
      */
-    private void emptyQueue() {
+    private void drainQueue() {
         if (queue == null) {
             return;
         }
-        int startSize = cacheMap.size();
+        int removedCount = 0;
         while (true) {
             @SuppressWarnings("unchecked")
-            CacheEntry<K,V> entry = (CacheEntry<K,V>)queue.poll();
-            if (entry == null) {
+            CacheEntry<K,V> ce = (CacheEntry<K,V>)queue.poll();
+            if (ce == null) {
                 break;
             }
-            K key = entry.getKey();
-            if (key == null) {
-                // key is null, entry has already been removed
-                continue;
-            }
-            CacheEntry<K,V> currentEntry = cacheMap.remove(key);
-            // check if the entry in the map corresponds to the expired
-            // entry. If not, readd the entry
-            if ((currentEntry != null) && (entry != currentEntry)) {
-                cacheMap.put(key, currentEntry);
+            if (ce.unlink(this)) {
+                if (cacheMap.remove(ce.getKey(), ce)) {
+                    removedCount++;
+                }
+                // the one who unlinks is responsible for invalidating
+                ce.invalidate();
             }
         }
         if (DEBUG) {
-            int endSize = cacheMap.size();
-            if (startSize != endSize) {
-                System.out.println("*** Expunged " + (startSize - endSize)
-                        + " entries, " + endSize + " entries left");
+            if (removedCount > 0) {
+                System.out.println("*** Expunged " + removedCount
+                        + " entries, " + cacheMap.size() + " entries left");
             }
         }
     }
 
     /**
-     * Scan all entries and remove all expired ones.
+     * Drain the reference queue and remove all corresponding entries
+     * from the cache. Then expunge all expired entries.
+     *
+     * This method should be called just before making any decisions based
+     * on the number of entries remaining in cache (i.e. cacheMap.size()).
      */
-    private void expungeExpiredEntries() {
-        emptyQueue();
+    private void drainQueueExpungeExpired() {
+        drainQueue();
+        int cnt = 0;
+        long currentTime = System.nanoTime();
+        long lifetime = this.lifetime;
         if (lifetime == 0) {
             return;
         }
-        int cnt = 0;
-        long time = System.currentTimeMillis();
-        for (Iterator<CacheEntry<K,V>> t = cacheMap.values().iterator();
-                t.hasNext(); ) {
-            CacheEntry<K,V> entry = t.next();
-            if (entry.isValid(time) == false) {
-                t.remove();
+        Predicate<CacheEntry<K, V>> entryIsInvalid =
+            ce -> !ce.isValid(currentTime, lifetime);
+        CacheEntry<K,V> lru;
+        while ((lru = CacheEntry.unlinkLruIf(this, entryIsInvalid)) != null) {
+            if (cacheMap.remove(lru.getKey(), lru)) {
                 cnt++;
             }
+            // the one who unlinks is responsible for invalidating
+            lru.invalidate();
         }
         if (DEBUG) {
             if (cnt != 0) {
                 System.out.println("Removed " + cnt
                         + " expired entries, remaining " + cacheMap.size());
             }
         }
     }
 
-    public synchronized int size() {
-        expungeExpiredEntries();
-        return cacheMap.size();
+     /**
+     * Remove LRU entries while cache size is greater than given {@code maxSize}.
+     *
+     * @param maxSize pack cache to size no more than
+     */
+    private void reduceToMaxSize(int maxSize) {
+        if (maxSize > 0 && cacheMap.size() > maxSize) {
+            // 1st get rid of all cleared and expired entries
+            drainQueueExpungeExpired();
+            while (cacheMap.size() > maxSize) { // while still too large
+                CacheEntry<K, V> lru = CacheEntry.unlinkLruIf(this, ce -> true);
+                if (lru != null) {
+                    if (DEBUG) {
+                        System.out.println("** Overflow removal "
+                            + lru.getKey() + " | " + lru.getValue());
     }
-
-    public synchronized void clear() {
-        if (queue != null) {
-            // if this is a SoftReference cache, first invalidate() all
-            // entries so that GC does not have to enqueue them
-            for (CacheEntry<K,V> entry : cacheMap.values()) {
-                entry.invalidate();
+                    cacheMap.remove(lru.getKey(), lru);
+                    // the one who unlinks is responsible for invalidating
+                    lru.invalidate();
             }
-            while (queue.poll() != null) {
-                // empty
             }
         }
-        cacheMap.clear();
     }
 
-    public synchronized void put(K key, V value) {
-        emptyQueue();
-        long expirationTime = (lifetime == 0) ? 0 :
-                                        System.currentTimeMillis() + lifetime;
-        CacheEntry<K,V> newEntry = newEntry(key, value, expirationTime, queue);
-        CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
-        if (oldEntry != null) {
-            oldEntry.invalidate();
-            return;
+   public int size() {
+        drainQueueExpungeExpired();
+        return cacheMap.size();
         }
-        if (maxSize > 0 && cacheMap.size() > maxSize) {
-            expungeExpiredEntries();
-            if (cacheMap.size() > maxSize) { // still too large?
-                Iterator<CacheEntry<K,V>> t = cacheMap.values().iterator();
-                CacheEntry<K,V> lruEntry = t.next();
-                if (DEBUG) {
-                    System.out.println("** Overflow removal "
-                        + lruEntry.getKey() + " | " + lruEntry.getValue());
+
+    public void clear() {
+        CacheEntry<K, V> lru;
+        while ((lru = CacheEntry.unlinkLruIf(this, ce -> true)) != null) {
+            cacheMap.remove(lru.getKey(), lru);
+            // the one who unlinks is responsible for invalidating
+            lru.invalidate();
                 }
-                t.remove();
-                lruEntry.invalidate();
             }
+
+    public void put(K key, V value) {
+        drainQueue();
+        CacheEntry<K,V> newEntry = newEntry(key, value, queue);
+        newEntry.link(this);
+        CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
+        if (oldEntry != null && oldEntry.unlink(this)) {
+            // the one who unlinks is responsible for invalidating
+            oldEntry.invalidate();
+            return;
         }
+        reduceToMaxSize(maxSize);
     }
 
-    public synchronized V get(Object key) {
-        emptyQueue();
+    public V get(Object key) {
+        drainQueue();
         CacheEntry<K,V> entry = cacheMap.get(key);
         if (entry == null) {
             return null;
         }
-        long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
-        if (entry.isValid(time) == false) {
+        // harmless data race: entry.isValid() vs. entry.invalidate()
+        if (!entry.isValid(System.nanoTime(), lifetime)) {
             if (DEBUG) {
                 System.out.println("Ignoring expired entry");
             }
-            cacheMap.remove(key);
+            if (entry.unlink(this)) {
+                cacheMap.remove(entry.getKey(), entry);
+                // the one who unlinks is responsible for invalidating
+                entry.invalidate();
+            }
             return null;
         }
-        return entry.getValue();
+        return entry.getValue(); // harmless data race with entry.invalidate()
     }
 
-    public synchronized void remove(Object key) {
-        emptyQueue();
-        CacheEntry<K,V> entry = cacheMap.remove(key);
+    public void remove(Object key) {
+        drainQueue();
+        CacheEntry<K,V> entry = cacheMap.get(key);
         if (entry != null) {
+            if (entry.unlink(this)) {
+                cacheMap.remove(entry.getKey(), entry);
+                // the one who unlinks is responsible for invalidating
             entry.invalidate();
         }
     }
-
-    public synchronized void setCapacity(int size) {
-        expungeExpiredEntries();
-        if (size > 0 && cacheMap.size() > size) {
-            Iterator<CacheEntry<K,V>> t = cacheMap.values().iterator();
-            for (int i = cacheMap.size() - size; i > 0; i--) {
-                CacheEntry<K,V> lruEntry = t.next();
-                if (DEBUG) {
-                    System.out.println("** capacity reset removal "
-                        + lruEntry.getKey() + " | " + lruEntry.getValue());
-                }
-                t.remove();
-                lruEntry.invalidate();
-            }
         }
 
-        maxSize = size > 0 ? size : 0;
+    public void setCapacity(int size) {
+        if (size < 0) size = 0;
+        maxSize = size;
+        // in case maxSize was reduced, immediately reduce the cache too
+        reduceToMaxSize(size);
 
         if (DEBUG) {
             System.out.println("** capacity reset to " + size);
         }
     }
 
-    public synchronized void setTimeout(int timeout) {
-        emptyQueue();
-        lifetime = timeout > 0 ? timeout * 1000L : 0L;
-
+    public void setTimeout(int timeout) {
+        this.lifetime = timeout < 0 ? 0 : TimeUnit.SECONDS.toNanos(timeout);
+        // in case timeout was shortened, immediately expunge newly expired entries
+        drainQueueExpungeExpired();
         if (DEBUG) {
-            System.out.println("** lifetime reset to " + timeout);
+            System.out.println("** lifetime reset to " + lifetime + " nanos");
         }
     }
 
     // it is a heavyweight method.
-    public synchronized void accept(CacheVisitor<K,V> visitor) {
-        expungeExpiredEntries();
+    public void accept(CacheVisitor<K,V> visitor) {
         Map<K,V> cached = getCachedEntries();
-
         visitor.visit(cached);
     }
 
+    // return a snapshot of the valid/non-expired part of the cache
     private Map<K,V> getCachedEntries() {
+        drainQueueExpungeExpired();
         Map<K,V> kvmap = new HashMap<>(cacheMap.size());
 
         for (CacheEntry<K,V> entry : cacheMap.values()) {
-            kvmap.put(entry.getKey(), entry.getValue());
+            K key = entry.getKey(); // harmless data race with entry.invalidate()
+            V val = entry.getValue(); // harmless data race with entry.invalidate()
+            if (key != null && val != null) {
+                kvmap.put(key, val);
+            }
         }
 
         return kvmap;
     }
 
-    protected CacheEntry<K,V> newEntry(K key, V value,
-            long expirationTime, ReferenceQueue<V> queue) {
+    protected CacheEntry<K,V> newEntry(K key, V value, ReferenceQueue<V> queue) {
         if (queue != null) {
-            return new SoftCacheEntry<>(key, value, expirationTime, queue);
+            return new SoftCacheEntry<>(key, value, queue);
         } else {
-            return new HardCacheEntry<>(key, value, expirationTime);
+            return new HardCacheEntry<>(key, value);
         }
     }
 
-    private static interface CacheEntry<K,V> {
+    private interface CacheEntry<K, V> {
 
-        boolean isValid(long currentTime);
+        boolean isValid(long currentTime, long lifetime);
 
         void invalidate();
 
         K getKey();
 
         V getValue();
 
+        // double-linked-list management
+
+        CacheEntry<K,V> prev();
+
+        CacheEntry<K,V> next();
+
+        void setPrev(CacheEntry<K,V> newPrev);
+
+        void setNext(CacheEntry<K,V> newNext);
+
+        void entryLinked();
+
+        default void link(MemoryCache<K, V> memoryCache) {
+            synchronized (memoryCache) {
+                assert prev() == this && next() == this : "Entry already linked";
+                if (memoryCache.lruEntry == null) {
+                    memoryCache.lruEntry = memoryCache.mruEntry = this;
+                    setPrev(null); setNext(null);
+                } else {
+                    setPrev(memoryCache.mruEntry);
+                    memoryCache.mruEntry.setNext(this);
+                    memoryCache.mruEntry = this;
+                    setNext(null);
+                }
+                entryLinked();
+            }
+        }
+
+        default boolean unlink(MemoryCache<K, V> memoryCache) {
+            synchronized (memoryCache) {
+                CacheEntry<K, V> next = next();
+                CacheEntry<K, V> prev = prev();
+                if (next == this && prev == this) {
+                    // not linked
+                    return false;
+                }
+                if (memoryCache.lruEntry == this) {
+                    memoryCache.lruEntry = next;
+                }
+                if (memoryCache.mruEntry == this) {
+                    memoryCache.mruEntry = prev;
+                }
+                if (prev != null) {
+                    prev.setNext(next);
+                }
+                if (next != null) {
+                    next.setPrev(prev);
+                }
+                setPrev(this);
+                setNext(this);
+                return true;
+            }
+        }
+
+        static <K, V> CacheEntry<K, V> unlinkLruIf(MemoryCache<K, V> memoryCache, Predicate<CacheEntry<K, V>> predicate) {
+            synchronized (memoryCache) {
+                CacheEntry<K, V> lru = memoryCache.lruEntry;
+                if (lru == null || !predicate.test(lru)) {
+                    return null;
+                }
+                return lru.unlink(memoryCache) ? lru : null;
+            }
+        }
     }
 
     private static class HardCacheEntry<K,V> implements CacheEntry<K,V> {
 
         private K key;
         private V value;
-        private long expirationTime;
+        private long linkTime;
+        private CacheEntry<K, V> prev = this, next = this;
 
-        HardCacheEntry(K key, V value, long expirationTime) {
+        HardCacheEntry(K key, V value) {
             this.key = key;
             this.value = value;
-            this.expirationTime = expirationTime;
         }
 
         public K getKey() {
             return key;
         }
 
         public V getValue() {
             return value;
         }
 
-        public boolean isValid(long currentTime) {
-            boolean valid = (currentTime <= expirationTime);
-            if (valid == false) {
-                invalidate();
-            }
-            return valid;
+        public boolean isValid(long currentTime, long lifetime) {
+            return value != null &&
+                   (lifetime == 0 || (currentTime - linkTime) <= lifetime);
         }
 
         public void invalidate() {
             key = null;
             value = null;
-            expirationTime = -1;
+        }
+
+        @Override
+        public CacheEntry<K, V> prev() {
+            return prev;
+        }
+
+        @Override
+        public CacheEntry<K, V> next() {
+            return next;
+        }
+
+        @Override
+        public void setPrev(CacheEntry<K, V> newPrev) {
+            prev = newPrev;
+        }
+
+        @Override
+        public void setNext(CacheEntry<K, V> newNext) {
+            next = newNext;
+        }
+
+        @Override
+        public void entryLinked() {
+            // sample link time while synchronized which guarantees
+            // monotonic time increment (no decrement), so dl-list
+            // is kept sorted by linkTime
+            linkTime = System.nanoTime();
         }
     }
 
     private static class SoftCacheEntry<K,V>
             extends SoftReference<V>
             implements CacheEntry<K,V> {
 
         private K key;
-        private long expirationTime;
+        private long linkTime;
+        private CacheEntry<K, V> prev = this, next = this;
 
-        SoftCacheEntry(K key, V value, long expirationTime,
-                ReferenceQueue<V> queue) {
+        SoftCacheEntry(K key, V value, ReferenceQueue<V> queue) {
             super(value, queue);
             this.key = key;
-            this.expirationTime = expirationTime;
         }
 
         public K getKey() {
             return key;
         }
 
         public V getValue() {
             return get();
         }
 
-        public boolean isValid(long currentTime) {
-            boolean valid = (currentTime <= expirationTime) && (get() != null);
-            if (valid == false) {
-                invalidate();
-            }
-            return valid;
+        public boolean isValid(long currentTime, long lifetime) {
+            return get() != null &&
+                   (lifetime == 0 || (currentTime - linkTime) <= lifetime);
         }
 
         public void invalidate() {
-            clear();
             key = null;
-            expirationTime = -1;
+            clear();
+        }
+
+        @Override
+        public CacheEntry<K, V> prev() {
+            return prev;
+        }
+
+        @Override
+        public CacheEntry<K, V> next() {
+            return next;
+        }
+
+        @Override
+        public void setPrev(CacheEntry<K, V> newPrev) {
+            prev = newPrev;
+        }
+
+        @Override
+        public void setNext(CacheEntry<K, V> newNext) {
+            next = newNext;
+        }
+
+        @Override
+        public void entryLinked() {
+            // sample link time while synchronized which guarantees
+            // monotonic time increment (no decrement), so dl-list
+            // is kept sorted by linkTime
+            linkTime = System.nanoTime();
         }
     }
 
 }
< prev index next >