001 /* 002 * Copyright 2010-2015 JetBrains s.r.o. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017 package org.jetbrains.kotlin.storage; 018 019 import kotlin.Unit; 020 import kotlin.jvm.functions.Function0; 021 import kotlin.jvm.functions.Function1; 022 import org.jetbrains.annotations.NotNull; 023 import org.jetbrains.annotations.Nullable; 024 import org.jetbrains.kotlin.utils.UtilsPackage; 025 import org.jetbrains.kotlin.utils.WrappedValues; 026 027 import java.util.Arrays; 028 import java.util.List; 029 import java.util.concurrent.ConcurrentHashMap; 030 import java.util.concurrent.ConcurrentMap; 031 import java.util.concurrent.locks.Lock; 032 import java.util.concurrent.locks.ReentrantLock; 033 034 public class LockBasedStorageManager implements StorageManager { 035 public interface ExceptionHandlingStrategy { 036 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() { 037 @NotNull 038 @Override 039 public RuntimeException handleException(@NotNull Throwable throwable) { 040 throw UtilsPackage.rethrow(throwable); 041 } 042 }; 043 044 /* 045 * The signature of this method is a trick: it is used as 046 * 047 * throw strategy.handleException(...) 048 * 049 * most implementations of this method throw exceptions themselves, so it does not matter what they return 050 */ 051 @NotNull 052 RuntimeException handleException(@NotNull Throwable throwable); 053 } 054 055 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) { 056 @NotNull 057 @Override 058 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 059 return RecursionDetectedResult.fallThrough(); 060 } 061 }; 062 063 @NotNull 064 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 065 return new LockBasedStorageManager(exceptionHandlingStrategy); 066 } 067 068 protected final Lock lock; 069 private final ExceptionHandlingStrategy exceptionHandlingStrategy; 070 private final String debugText; 071 072 private LockBasedStorageManager( 073 @NotNull String debugText, 074 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy, 075 @NotNull Lock lock 076 ) { 077 this.lock = lock; 078 this.exceptionHandlingStrategy = exceptionHandlingStrategy; 079 this.debugText = debugText; 080 } 081 082 public LockBasedStorageManager() { 083 this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock()); 084 } 085 086 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 087 this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock()); 088 } 089 090 private static String getPointOfConstruction() { 091 StackTraceElement[] trace = Thread.currentThread().getStackTrace(); 092 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it 093 if (trace.length <= 3) return "<unknown creating class>"; 094 return trace[3].toString(); 095 } 096 097 @Override 098 public String toString() { 099 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")"; 100 } 101 102 @NotNull 103 @Override 104 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) { 105 return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap()); 106 } 107 108 @NotNull 109 @Override 110 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction( 111 @NotNull Function1<? super K, ? extends V> compute, 112 @NotNull ConcurrentMap<K, Object> map 113 ) { 114 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute); 115 } 116 117 @NotNull 118 @Override 119 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) { 120 return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap()); 121 } 122 123 @Override 124 @NotNull 125 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues( 126 @NotNull Function1<? super K, ? extends V> compute, 127 @NotNull ConcurrentMap<K, Object> map 128 ) { 129 return new MapBasedMemoizedFunction<K, V>(map, compute); 130 } 131 132 @NotNull 133 @Override 134 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) { 135 return new LockBasedNotNullLazyValue<T>(computable); 136 } 137 138 @NotNull 139 @Override 140 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue( 141 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall 142 ) { 143 return new LockBasedNotNullLazyValue<T>(computable) { 144 @NotNull 145 @Override 146 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 147 return RecursionDetectedResult.value(onRecursiveCall); 148 } 149 }; 150 } 151 152 @NotNull 153 @Override 154 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute( 155 @NotNull Function0<? extends T> computable, 156 final Function1<? super Boolean, ? extends T> onRecursiveCall, 157 @NotNull final Function1<? super T, ? extends Unit> postCompute 158 ) { 159 return new LockBasedNotNullLazyValue<T>(computable) { 160 @NotNull 161 @Override 162 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 163 if (onRecursiveCall == null) { 164 return super.recursionDetected(firstTime); 165 } 166 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime)); 167 } 168 169 @Override 170 protected void postCompute(@NotNull T value) { 171 postCompute.invoke(value); 172 } 173 }; 174 } 175 176 @NotNull 177 @Override 178 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) { 179 return new LockBasedLazyValue<T>(computable); 180 } 181 182 @NotNull 183 @Override 184 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) { 185 return new LockBasedLazyValue<T>(computable) { 186 @NotNull 187 @Override 188 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 189 return RecursionDetectedResult.value(onRecursiveCall); 190 } 191 }; 192 } 193 194 @NotNull 195 @Override 196 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute( 197 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute 198 ) { 199 return new LockBasedLazyValue<T>(computable) { 200 @Override 201 protected void postCompute(@Nullable T value) { 202 postCompute.invoke(value); 203 } 204 }; 205 } 206 207 @Override 208 public <T> T compute(@NotNull Function0<? extends T> computable) { 209 lock.lock(); 210 try { 211 return computable.invoke(); 212 } 213 catch (Throwable throwable) { 214 throw exceptionHandlingStrategy.handleException(throwable); 215 } 216 finally { 217 lock.unlock(); 218 } 219 } 220 221 @NotNull 222 private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() { 223 // memory optimization: fewer segments and entries stored 224 return new ConcurrentHashMap<K, Object>(3, 1, 2); 225 } 226 227 @NotNull 228 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 229 throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this)); 230 } 231 232 private static class RecursionDetectedResult<T> { 233 234 @NotNull 235 public static <T> RecursionDetectedResult<T> value(T value) { 236 return new RecursionDetectedResult<T>(value, false); 237 } 238 239 @NotNull 240 public static <T> RecursionDetectedResult<T> fallThrough() { 241 return new RecursionDetectedResult<T>(null, true); 242 } 243 244 private final T value; 245 private final boolean fallThrough; 246 247 private RecursionDetectedResult(T value, boolean fallThrough) { 248 this.value = value; 249 this.fallThrough = fallThrough; 250 } 251 252 public T getValue() { 253 assert !fallThrough : "A value requested from FALL_THROUGH in " + this; 254 return value; 255 } 256 257 public boolean isFallThrough() { 258 return fallThrough; 259 } 260 261 @Override 262 public String toString() { 263 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value); 264 } 265 } 266 267 private enum NotValue { 268 NOT_COMPUTED, 269 COMPUTING, 270 RECURSION_WAS_DETECTED 271 } 272 273 private class LockBasedLazyValue<T> implements NullableLazyValue<T> { 274 275 private final Function0<? extends T> computable; 276 277 @Nullable 278 private volatile Object value = NotValue.NOT_COMPUTED; 279 280 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) { 281 this.computable = computable; 282 } 283 284 @Override 285 public boolean isComputed() { 286 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING; 287 } 288 289 @Override 290 public T invoke() { 291 Object _value = value; 292 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 293 294 lock.lock(); 295 try { 296 _value = value; 297 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 298 299 if (_value == NotValue.COMPUTING) { 300 value = NotValue.RECURSION_WAS_DETECTED; 301 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true); 302 if (!result.isFallThrough()) { 303 return result.getValue(); 304 } 305 } 306 307 if (_value == NotValue.RECURSION_WAS_DETECTED) { 308 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false); 309 if (!result.isFallThrough()) { 310 return result.getValue(); 311 } 312 } 313 314 value = NotValue.COMPUTING; 315 try { 316 T typedValue = computable.invoke(); 317 value = typedValue; 318 postCompute(typedValue); 319 return typedValue; 320 } 321 catch (Throwable throwable) { 322 if (value == NotValue.COMPUTING) { 323 // Store only if it's a genuine result, not something thrown through recursionDetected() 324 value = WrappedValues.escapeThrowable(throwable); 325 } 326 throw exceptionHandlingStrategy.handleException(throwable); 327 } 328 } 329 finally { 330 lock.unlock(); 331 } 332 } 333 334 /** 335 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise 336 * @return a value to be returned on a recursive call or subsequent calls 337 */ 338 @NotNull 339 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 340 return recursionDetectedDefault(); 341 } 342 343 protected void postCompute(T value) { 344 // Doing something in post-compute helps prevent infinite recursion 345 } 346 } 347 348 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> { 349 350 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) { 351 super(computable); 352 } 353 354 @Override 355 @NotNull 356 public T invoke() { 357 T result = super.invoke(); 358 assert result != null : "compute() returned null"; 359 return result; 360 } 361 } 362 363 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> { 364 private final ConcurrentMap<K, Object> cache; 365 private final Function1<? super K, ? extends V> compute; 366 367 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) { 368 this.cache = map; 369 this.compute = compute; 370 } 371 372 @Override 373 @Nullable 374 public V invoke(K input) { 375 Object value = cache.get(input); 376 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value); 377 378 lock.lock(); 379 try { 380 value = cache.get(input); 381 if (value == NotValue.COMPUTING) { 382 throw recursionDetected(input); 383 } 384 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 385 386 AssertionError error = null; 387 try { 388 cache.put(input, NotValue.COMPUTING); 389 V typedValue = compute.invoke(input); 390 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue)); 391 392 // This code effectively asserts that oldValue is null 393 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored 394 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that 395 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed 396 if (oldValue != NotValue.COMPUTING) { 397 error = raceCondition(input, oldValue); 398 throw error; 399 } 400 401 return typedValue; 402 } 403 catch (Throwable throwable) { 404 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable); 405 406 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable)); 407 if (oldValue != NotValue.COMPUTING) { 408 throw raceCondition(input, oldValue); 409 } 410 411 throw exceptionHandlingStrategy.handleException(throwable); 412 } 413 } 414 finally { 415 lock.unlock(); 416 } 417 } 418 419 @NotNull 420 private AssertionError recursionDetected(K input) { 421 return sanitizeStackTrace( 422 new AssertionError("Recursion detected on input: " + input + " under " + LockBasedStorageManager.this) 423 ); 424 } 425 426 @NotNull 427 private AssertionError raceCondition(K input, Object oldValue) { 428 return sanitizeStackTrace( 429 new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue + 430 " under " + LockBasedStorageManager.this) 431 ); 432 } 433 434 @Override 435 public boolean isComputed(K key) { 436 Object value = cache.get(key); 437 return value != null && value != NotValue.COMPUTING; 438 } 439 } 440 441 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> { 442 443 public MapBasedMemoizedFunctionToNotNull( 444 @NotNull ConcurrentMap<K, Object> map, 445 @NotNull Function1<? super K, ? extends V> compute 446 ) { 447 super(map, compute); 448 } 449 450 @NotNull 451 @Override 452 public V invoke(K input) { 453 V result = super.invoke(input); 454 assert result != null : "compute() returned null under " + LockBasedStorageManager.this; 455 return result; 456 } 457 } 458 459 @NotNull 460 public static LockBasedStorageManager createDelegatingWithSameLock( 461 @NotNull LockBasedStorageManager base, 462 @NotNull ExceptionHandlingStrategy newStrategy 463 ) { 464 return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock); 465 } 466 467 @NotNull 468 private static <T extends Throwable> T sanitizeStackTrace(@NotNull T throwable) { 469 String storagePackageName = LockBasedStorageManager.class.getPackage().getName(); 470 StackTraceElement[] stackTrace = throwable.getStackTrace(); 471 int size = stackTrace.length; 472 473 int firstNonStorage = -1; 474 for (int i = 0; i < size; i++) { 475 // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage 476 if (!stackTrace[i].getClassName().startsWith(storagePackageName)) { 477 firstNonStorage = i; 478 break; 479 } 480 } 481 assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager"; 482 483 List<StackTraceElement> list = Arrays.asList(stackTrace).subList(firstNonStorage, size); 484 throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()])); 485 return throwable; 486 } 487 }