001 /* 002 * Copyright 2010-2013 JetBrains s.r.o. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017 package org.jetbrains.jet.storage; 018 019 import jet.Function0; 020 import jet.Function1; 021 import jet.Unit; 022 import org.jetbrains.annotations.NotNull; 023 import org.jetbrains.annotations.Nullable; 024 import org.jetbrains.jet.utils.ExceptionUtils; 025 import org.jetbrains.jet.utils.WrappedValues; 026 027 import java.util.concurrent.ConcurrentHashMap; 028 import java.util.concurrent.ConcurrentMap; 029 import java.util.concurrent.locks.Lock; 030 import java.util.concurrent.locks.ReentrantLock; 031 032 public class LockBasedStorageManager implements StorageManager { 033 034 public static final StorageManager NO_LOCKS = new LockBasedStorageManager(NoLock.INSTANCE) { 035 @Override 036 public String toString() { 037 return "NO_LOCKS"; 038 } 039 }; 040 041 protected final Lock lock; 042 043 public LockBasedStorageManager() { 044 this(new ReentrantLock()); 045 } 046 047 private LockBasedStorageManager(@NotNull Lock lock) { 048 this.lock = lock; 049 } 050 051 @NotNull 052 @Override 053 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<K, V> compute) { 054 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>()); 055 } 056 057 @NotNull 058 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction( 059 @NotNull Function1<K, V> compute, 060 @NotNull ConcurrentMap<K, Object> map 061 ) { 062 return new MapBasedMemoizedFunctionToNotNull<K, V>(lock, map, compute); 063 } 064 065 @NotNull 066 @Override 067 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<K, V> compute) { 068 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>()); 069 } 070 071 @NotNull 072 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues( 073 @NotNull Function1<K, V> compute, 074 @NotNull ConcurrentMap<K, Object> map 075 ) { 076 return new MapBasedMemoizedFunction<K, V>(lock, map, compute); 077 } 078 079 @NotNull 080 @Override 081 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<T> computable) { 082 return new LockBasedNotNullLazyValue<T>(lock, computable); 083 } 084 085 @NotNull 086 @Override 087 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue( 088 @NotNull Function0<T> computable, @NotNull final T onRecursiveCall 089 ) { 090 return new LockBasedNotNullLazyValue<T>(lock, computable) { 091 @Override 092 protected T recursionDetected(boolean firstTime) { 093 return onRecursiveCall; 094 } 095 }; 096 } 097 098 @NotNull 099 @Override 100 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute( 101 @NotNull Function0<T> computable, 102 final Function1<Boolean, T> onRecursiveCall, 103 @NotNull final Function1<T, Unit> postCompute 104 ) { 105 return new LockBasedNotNullLazyValue<T>(lock, computable) { 106 @Nullable 107 @Override 108 protected T recursionDetected(boolean firstTime) { 109 if (onRecursiveCall == null) { 110 return super.recursionDetected(firstTime); 111 } 112 return onRecursiveCall.invoke(firstTime); 113 } 114 115 @Override 116 protected void postCompute(@NotNull T value) { 117 postCompute.invoke(value); 118 } 119 }; 120 } 121 122 @NotNull 123 @Override 124 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<T> computable) { 125 return new LockBasedLazyValue<T>(lock, computable); 126 } 127 128 @NotNull 129 @Override 130 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<T> computable, final T onRecursiveCall) { 131 return new LockBasedLazyValue<T>(lock, computable) { 132 @Override 133 protected T recursionDetected(boolean firstTime) { 134 return onRecursiveCall; 135 } 136 }; 137 } 138 139 @NotNull 140 @Override 141 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute( 142 @NotNull Function0<T> computable, @NotNull final Function1<T, Unit> postCompute 143 ) { 144 return new LockBasedLazyValue<T>(lock, computable) { 145 @Override 146 protected void postCompute(@Nullable T value) { 147 postCompute.invoke(value); 148 } 149 }; 150 } 151 152 @Override 153 public <T> T compute(@NotNull Function0<T> computable) { 154 lock.lock(); 155 try { 156 return computable.invoke(); 157 } 158 finally { 159 lock.unlock(); 160 } 161 } 162 163 private static class LockBasedLazyValue<T> implements NullableLazyValue<T> { 164 165 private enum NotValue { 166 NOT_COMPUTED, 167 COMPUTING, 168 RECURSION_WAS_DETECTED 169 } 170 171 private final Lock lock; 172 private final Function0<T> computable; 173 174 @Nullable 175 private volatile Object value = NotValue.NOT_COMPUTED; 176 177 public LockBasedLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) { 178 this.lock = lock; 179 this.computable = computable; 180 } 181 182 @Override 183 public boolean isComputed() { 184 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING; 185 } 186 187 @Override 188 public T invoke() { 189 Object _value = value; 190 if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 191 192 lock.lock(); 193 try { 194 _value = value; 195 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 196 197 if (_value == NotValue.COMPUTING) { 198 value = NotValue.RECURSION_WAS_DETECTED; 199 return recursionDetected(/*firstTime = */ true); 200 } 201 202 if (_value == NotValue.RECURSION_WAS_DETECTED) { 203 return recursionDetected(/*firstTime = */ false); 204 } 205 206 value = NotValue.COMPUTING; 207 try { 208 T typedValue = computable.invoke(); 209 value = typedValue; 210 postCompute(typedValue); 211 return typedValue; 212 } 213 catch (Throwable throwable) { 214 if (value == NotValue.COMPUTING) { 215 // Store only if it's a genuine result, not something thrown through recursionDetected() 216 value = WrappedValues.escapeThrowable(throwable); 217 } 218 throw ExceptionUtils.rethrow(throwable); 219 } 220 } 221 finally { 222 lock.unlock(); 223 } 224 } 225 226 /** 227 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise 228 * @return a value to be returned on a recursive call or subsequent calls 229 */ 230 @Nullable 231 protected T recursionDetected(boolean firstTime) { 232 throw new IllegalStateException("Recursive call in a lazy value"); 233 } 234 235 protected void postCompute(T value) { 236 // Doing something in post-compute helps prevent infinite recursion 237 } 238 } 239 240 private static class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> { 241 242 public LockBasedNotNullLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) { 243 super(lock, computable); 244 } 245 246 @Override 247 @NotNull 248 public T invoke() { 249 T result = super.invoke(); 250 assert result != null : "compute() returned null"; 251 return result; 252 } 253 } 254 255 private static class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> { 256 private final Lock lock; 257 private final ConcurrentMap<K, Object> cache; 258 private final Function1<K, V> compute; 259 260 public MapBasedMemoizedFunction(@NotNull Lock lock, @NotNull ConcurrentMap<K, Object> map, @NotNull Function1<K, V> compute) { 261 this.lock = lock; 262 this.cache = map; 263 this.compute = compute; 264 } 265 266 @Override 267 @Nullable 268 public V invoke(K input) { 269 Object value = cache.get(input); 270 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 271 272 lock.lock(); 273 try { 274 value = cache.get(input); 275 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 276 277 try { 278 V typedValue = compute.invoke(input); 279 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue)); 280 assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue; 281 282 return typedValue; 283 } 284 catch (Throwable throwable) { 285 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable)); 286 assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue; 287 288 throw ExceptionUtils.rethrow(throwable); 289 } 290 } 291 finally { 292 lock.unlock(); 293 } 294 } 295 } 296 297 private static class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> { 298 299 public MapBasedMemoizedFunctionToNotNull( 300 @NotNull Lock lock, 301 @NotNull ConcurrentMap<K, Object> map, 302 @NotNull Function1<K, V> compute 303 ) { 304 super(lock, map, compute); 305 } 306 307 @NotNull 308 @Override 309 public V invoke(K input) { 310 V result = super.invoke(input); 311 assert result != null : "compute() returned null"; 312 return result; 313 } 314 } 315 }