001 /* 002 * Copyright 2010-2013 JetBrains s.r.o. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017 package org.jetbrains.jet.storage; 018 019 import kotlin.Function0; 020 import kotlin.Function1; 021 import kotlin.Unit; 022 import org.jetbrains.annotations.NotNull; 023 import org.jetbrains.annotations.Nullable; 024 import org.jetbrains.jet.utils.UtilsPackage; 025 import org.jetbrains.jet.utils.WrappedValues; 026 027 import java.util.concurrent.ConcurrentHashMap; 028 import java.util.concurrent.ConcurrentMap; 029 import java.util.concurrent.locks.Lock; 030 import java.util.concurrent.locks.ReentrantLock; 031 032 public class LockBasedStorageManager implements StorageManager { 033 034 public interface ExceptionHandlingStrategy { 035 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() { 036 @NotNull 037 @Override 038 public RuntimeException handleException(@NotNull Throwable throwable) { 039 throw UtilsPackage.rethrow(throwable); 040 } 041 }; 042 043 /* 044 * The signature of this method is a trick: it is used as 045 * 046 * throw strategy.handleException(...) 047 * 048 * most implementations of this method throw exceptions themselves, so it does not matter what they return 049 */ 050 @NotNull 051 RuntimeException handleException(@NotNull Throwable throwable); 052 } 053 054 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) { 055 @NotNull 056 @Override 057 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 058 return RecursionDetectedResult.fallThrough(); 059 } 060 }; 061 062 @NotNull 063 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 064 return new LockBasedStorageManager(exceptionHandlingStrategy); 065 } 066 067 protected final Lock lock; 068 private final ExceptionHandlingStrategy exceptionHandlingStrategy; 069 private final String debugText; 070 071 private LockBasedStorageManager( 072 @NotNull String debugText, 073 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy, 074 @NotNull Lock lock 075 ) { 076 this.lock = lock; 077 this.exceptionHandlingStrategy = exceptionHandlingStrategy; 078 this.debugText = debugText; 079 } 080 081 public LockBasedStorageManager() { 082 this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock()); 083 } 084 085 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 086 this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock()); 087 } 088 089 private static String getPointOfConstruction() { 090 StackTraceElement[] trace = Thread.currentThread().getStackTrace(); 091 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it 092 if (trace.length <= 3) return "<unknown creating class>"; 093 return trace[3].toString(); 094 } 095 096 @Override 097 public String toString() { 098 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")"; 099 } 100 101 @NotNull 102 @Override 103 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) { 104 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>()); 105 } 106 107 @NotNull 108 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction( 109 @NotNull Function1<? super K, ? extends V> compute, 110 @NotNull ConcurrentMap<K, Object> map 111 ) { 112 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute); 113 } 114 115 @NotNull 116 @Override 117 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) { 118 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>()); 119 } 120 121 @NotNull 122 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues( 123 @NotNull Function1<? super K, ? extends V> compute, 124 @NotNull ConcurrentMap<K, Object> map 125 ) { 126 return new MapBasedMemoizedFunction<K, V>(map, compute); 127 } 128 129 @NotNull 130 @Override 131 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) { 132 return new LockBasedNotNullLazyValue<T>(computable); 133 } 134 135 @NotNull 136 @Override 137 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue( 138 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall 139 ) { 140 return new LockBasedNotNullLazyValue<T>(computable) { 141 @NotNull 142 @Override 143 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 144 return RecursionDetectedResult.value(onRecursiveCall); 145 } 146 }; 147 } 148 149 @NotNull 150 @Override 151 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute( 152 @NotNull Function0<? extends T> computable, 153 final Function1<? super Boolean, ? extends T> onRecursiveCall, 154 @NotNull final Function1<? super T, ? extends Unit> postCompute 155 ) { 156 return new LockBasedNotNullLazyValue<T>(computable) { 157 @NotNull 158 @Override 159 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 160 if (onRecursiveCall == null) { 161 return super.recursionDetected(firstTime); 162 } 163 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime)); 164 } 165 166 @Override 167 protected void postCompute(@NotNull T value) { 168 postCompute.invoke(value); 169 } 170 }; 171 } 172 173 @NotNull 174 @Override 175 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) { 176 return new LockBasedLazyValue<T>(computable); 177 } 178 179 @NotNull 180 @Override 181 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) { 182 return new LockBasedLazyValue<T>(computable) { 183 @NotNull 184 @Override 185 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 186 return RecursionDetectedResult.value(onRecursiveCall); 187 } 188 }; 189 } 190 191 @NotNull 192 @Override 193 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute( 194 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute 195 ) { 196 return new LockBasedLazyValue<T>(computable) { 197 @Override 198 protected void postCompute(@Nullable T value) { 199 postCompute.invoke(value); 200 } 201 }; 202 } 203 204 @Override 205 public <T> T compute(@NotNull Function0<? extends T> computable) { 206 lock.lock(); 207 try { 208 return computable.invoke(); 209 } 210 catch (Throwable throwable) { 211 throw exceptionHandlingStrategy.handleException(throwable); 212 } 213 finally { 214 lock.unlock(); 215 } 216 } 217 218 @NotNull 219 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 220 throw new IllegalStateException("Recursive call in a lazy value under " + this); 221 } 222 223 private static class RecursionDetectedResult<T> { 224 225 @NotNull 226 public static <T> RecursionDetectedResult<T> value(T value) { 227 return new RecursionDetectedResult<T>(value, false); 228 } 229 230 @NotNull 231 public static <T> RecursionDetectedResult<T> fallThrough() { 232 return new RecursionDetectedResult<T>(null, true); 233 } 234 235 private final T value; 236 private final boolean fallThrough; 237 238 private RecursionDetectedResult(T value, boolean fallThrough) { 239 this.value = value; 240 this.fallThrough = fallThrough; 241 } 242 243 public T getValue() { 244 assert !fallThrough : "A value requested from FALL_THROUGH in " + this; 245 return value; 246 } 247 248 public boolean isFallThrough() { 249 return fallThrough; 250 } 251 252 @Override 253 public String toString() { 254 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value); 255 } 256 } 257 258 private enum NotValue { 259 NOT_COMPUTED, 260 COMPUTING, 261 RECURSION_WAS_DETECTED 262 } 263 264 private class LockBasedLazyValue<T> implements NullableLazyValue<T> { 265 266 private final Function0<? extends T> computable; 267 268 @Nullable 269 private volatile Object value = NotValue.NOT_COMPUTED; 270 271 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) { 272 this.computable = computable; 273 } 274 275 @Override 276 public boolean isComputed() { 277 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING; 278 } 279 280 @Override 281 public T invoke() { 282 Object _value = value; 283 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 284 285 lock.lock(); 286 try { 287 _value = value; 288 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 289 290 if (_value == NotValue.COMPUTING) { 291 value = NotValue.RECURSION_WAS_DETECTED; 292 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true); 293 if (!result.isFallThrough()) { 294 return result.getValue(); 295 } 296 } 297 298 if (_value == NotValue.RECURSION_WAS_DETECTED) { 299 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false); 300 if (!result.isFallThrough()) { 301 return result.getValue(); 302 } 303 } 304 305 value = NotValue.COMPUTING; 306 try { 307 T typedValue = computable.invoke(); 308 value = typedValue; 309 postCompute(typedValue); 310 return typedValue; 311 } 312 catch (Throwable throwable) { 313 if (value == NotValue.COMPUTING) { 314 // Store only if it's a genuine result, not something thrown through recursionDetected() 315 value = WrappedValues.escapeThrowable(throwable); 316 } 317 throw exceptionHandlingStrategy.handleException(throwable); 318 } 319 } 320 finally { 321 lock.unlock(); 322 } 323 } 324 325 /** 326 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise 327 * @return a value to be returned on a recursive call or subsequent calls 328 */ 329 @NotNull 330 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 331 return recursionDetectedDefault(); 332 } 333 334 protected void postCompute(T value) { 335 // Doing something in post-compute helps prevent infinite recursion 336 } 337 } 338 339 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> { 340 341 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) { 342 super(computable); 343 } 344 345 @Override 346 @NotNull 347 public T invoke() { 348 T result = super.invoke(); 349 assert result != null : "compute() returned null"; 350 return result; 351 } 352 } 353 354 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> { 355 private final ConcurrentMap<K, Object> cache; 356 private final Function1<? super K, ? extends V> compute; 357 358 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) { 359 this.cache = map; 360 this.compute = compute; 361 } 362 363 @Override 364 @Nullable 365 public V invoke(K input) { 366 Object value = cache.get(input); 367 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value); 368 369 lock.lock(); 370 try { 371 value = cache.get(input); 372 assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this; 373 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 374 375 AssertionError error = null; 376 try { 377 cache.put(input, NotValue.COMPUTING); 378 V typedValue = compute.invoke(input); 379 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue)); 380 381 // This code effectively asserts that oldValue is null 382 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored 383 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that 384 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed 385 if (oldValue != NotValue.COMPUTING) { 386 error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue + 387 " under " + LockBasedStorageManager.this); 388 throw error; 389 } 390 391 return typedValue; 392 } 393 catch (Throwable throwable) { 394 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable); 395 396 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable)); 397 assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue + 398 " under " + LockBasedStorageManager.this; 399 400 throw exceptionHandlingStrategy.handleException(throwable); 401 } 402 } 403 finally { 404 lock.unlock(); 405 } 406 } 407 } 408 409 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> { 410 411 public MapBasedMemoizedFunctionToNotNull( 412 @NotNull ConcurrentMap<K, Object> map, 413 @NotNull Function1<? super K, ? extends V> compute 414 ) { 415 super(map, compute); 416 } 417 418 @NotNull 419 @Override 420 public V invoke(K input) { 421 V result = super.invoke(input); 422 assert result != null : "compute() returned null under " + LockBasedStorageManager.this; 423 return result; 424 } 425 } 426 427 @NotNull 428 public static LockBasedStorageManager createDelegatingWithSameLock( 429 @NotNull LockBasedStorageManager base, 430 @NotNull ExceptionHandlingStrategy newStrategy 431 ) { 432 return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock); 433 } 434 }