/*
 * Copyright 2010-2015 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.storage;

import kotlin.Unit;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlin.utils.ExceptionUtilsKt;
import org.jetbrains.kotlin.utils.WrappedValues;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * A {@link StorageManager} whose lazy values and memoized functions run their computations while holding
 * a single {@link Lock} owned by the manager.
 * <p>
 * Cached results are first looked up without taking the lock (lazy values keep their state in a {@code volatile}
 * field, memoized functions in a {@link ConcurrentMap}); only when no result is available the lock is taken,
 * the state is re-checked and the computation is run.
 * <p>
 * Every {@link Throwable} raised by a user computation is passed to the configured {@link ExceptionHandlingStrategy}.
 */
public class LockBasedStorageManager implements StorageManager {
    /**
     * Decides what to do with a {@link Throwable} raised by a computation run by the storage manager.
     */
    public interface ExceptionHandlingStrategy {
        /**
         * Strategy that hands the throwable over to {@link ExceptionUtilsKt#rethrow}.
         */
        ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
            @NotNull
            @Override
            public RuntimeException handleException(@NotNull Throwable throwable) {
                throw ExceptionUtilsKt.rethrow(throwable);
            }
        };

        /*
         * The signature of this method is a trick: it is used as
         *
         *     throw strategy.handleException(...)
         *
         * most implementations of this method throw exceptions themselves, so it does not matter what they return
         */
        @NotNull
        RuntimeException handleException(@NotNull Throwable throwable);
    }

    /**
     * A storage manager constructed with {@code NoLock.INSTANCE} instead of a {@link ReentrantLock}.
     * Unlike the default manager, it does not fail when a lazy value is requested recursively:
     * the default reaction to recursion is to fall through and run the computation again.
     */
    public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
        @NotNull
        @Override
        protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
            return RecursionDetectedResult.fallThrough();
        }
    };

    /**
     * Creates a manager with its own {@link ReentrantLock} and the given exception handling strategy.
     *
     * @param exceptionHandlingStrategy receives every throwable raised by computations run by the new manager
     * @return a new storage manager
     */
    @NotNull
    public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
        return new LockBasedStorageManager(exceptionHandlingStrategy);
    }

    /** The lock held while any computation of this manager is running. */
    protected final Lock lock;
    private final ExceptionHandlingStrategy exceptionHandlingStrategy;
    /** Free-form description printed by {@link #toString()}, normally the stack frame that created the manager. */
    private final String debugText;

    private LockBasedStorageManager(
            @NotNull String debugText,
            @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
            @NotNull Lock lock
    ) {
        this.lock = lock;
        this.exceptionHandlingStrategy = exceptionHandlingStrategy;
        this.debugText = debugText;
    }

    /**
     * Creates a manager with its own {@link ReentrantLock} and the {@link ExceptionHandlingStrategy#THROW} strategy.
     */
    public LockBasedStorageManager() {
        this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
    }

    /**
     * Creates a manager with its own {@link ReentrantLock} and the given exception handling strategy.
     */
    protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
        this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock());
    }

    /**
     * @return a textual form of the stack frame three levels above {@code getStackTrace()},
     * or a placeholder when the trace is too short
     */
    private static String getPointOfConstruction() {
        StackTraceElement[] trace = Thread.currentThread().getStackTrace();
        // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
        if (trace.length <= 3) return "<unknown creating class>";
        return trace[3].toString();
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
    }

    /**
     * Creates a memoized function backed by a fresh small {@link ConcurrentHashMap}.
     * The returned function asserts that {@code compute} never yields {@code null}.
     */
    @NotNull
    @Override
    public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
        return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
    }

    /**
     * Creates a memoized function that caches its results in the supplied map.
     * The returned function asserts that {@code compute} never yields {@code null}.
     *
     * @param compute the function to memoize
     * @param map     the cache to store results (and in-progress markers) in
     */
    @NotNull
    @Override
    public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
            @NotNull Function1<? super K, ? extends V> compute,
            @NotNull ConcurrentMap<K, Object> map
    ) {
        return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
    }

    /**
     * Creates a memoized function backed by a fresh small {@link ConcurrentHashMap};
     * {@code compute} is allowed to yield {@code null}.
     */
    @NotNull
    @Override
    public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
        return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
    }

    /**
     * Creates a memoized function that caches its results in the supplied map;
     * {@code compute} is allowed to yield {@code null}.
     *
     * @param compute the function to memoize
     * @param map     the cache to store results (and in-progress markers) in
     */
    @Override
    @NotNull
    public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
            @NotNull Function1<? super K, ? extends V> compute,
            @NotNull ConcurrentMap<K, Object> map
    ) {
        return new MapBasedMemoizedFunction<K, V>(map, compute);
    }

    /**
     * Creates a lazy value that asserts its computed result is not {@code null}.
     * A recursive request is handled by {@link #recursionDetectedDefault()}.
     */
    @NotNull
    @Override
    public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
        return new LockBasedNotNullLazyValue<T>(computable);
    }

    /**
     * Creates a not-null lazy value that answers {@code onRecursiveCall} whenever it is requested
     * while its own computation is still running.
     *
     * @param computable      produces the value
     * @param onRecursiveCall the value returned on recursive requests
     */
    @NotNull
    @Override
    public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
            @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
    ) {
        return new LockBasedNotNullLazyValue<T>(computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall);
            }
        };
    }

    /**
     * Creates a not-null lazy value that runs {@code postCompute} on the freshly computed value
     * after the value has already been stored.
     *
     * @param computable      produces the value
     * @param onRecursiveCall if not {@code null}, invoked on a recursive request with {@code true} when the recursion
     *                        has just been detected and {@code false} on later requests; its result is returned
     *                        to the recursive caller. If {@code null}, the default recursion handling applies
     * @param postCompute     invoked with the computed value after it has been stored
     */
    @NotNull
    @Override
    public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
            @NotNull Function0<? extends T> computable,
            @Nullable final Function1<? super Boolean, ? extends T> onRecursiveCall,
            @NotNull final Function1<? super T, Unit> postCompute
    ) {
        return new LockBasedNotNullLazyValue<T>(computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
                if (onRecursiveCall == null) {
                    return super.recursionDetected(firstTime);
                }
                return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
            }

            @Override
            protected void postCompute(@NotNull T value) {
                postCompute.invoke(value);
            }
        };
    }

    /**
     * Creates a lazy value whose computed result may be {@code null}.
     * A recursive request is handled by {@link #recursionDetectedDefault()}.
     */
    @NotNull
    @Override
    public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
        return new LockBasedLazyValue<T>(computable);
    }

    /**
     * Creates a nullable lazy value that answers {@code onRecursiveCall} whenever it is requested
     * while its own computation is still running.
     *
     * @param computable      produces the value
     * @param onRecursiveCall the value (possibly {@code null}) returned on recursive requests
     */
    @NotNull
    @Override
    public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
        return new LockBasedLazyValue<T>(computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall);
            }
        };
    }

    /**
     * Creates a nullable lazy value that runs {@code postCompute} on the freshly computed value
     * after the value has already been stored.
     *
     * @param computable  produces the value
     * @param postCompute invoked with the computed value (possibly {@code null}) after it has been stored
     */
    @NotNull
    @Override
    public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
            @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, Unit> postCompute
    ) {
        return new LockBasedLazyValue<T>(computable) {
            @Override
            protected void postCompute(@Nullable T value) {
                postCompute.invoke(value);
            }
        };
    }

    /**
     * Runs {@code computable} while holding this manager's lock.
     * Anything thrown by the computation is passed to the exception handling strategy.
     *
     * @param computable the computation to run
     * @return whatever {@code computable} returns
     */
    @Override
    public <T> T compute(@NotNull Function0<? extends T> computable) {
        lock.lock();
        try {
            return computable.invoke();
        }
        catch (Throwable throwable) {
            throw exceptionHandlingStrategy.handleException(throwable);
        }
        finally {
            lock.unlock();
        }
    }

    @NotNull
    private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() {
        // memory optimization: fewer segments and entries stored
        return new ConcurrentHashMap<K, Object>(3, 1, 2);
    }

    /**
     * Default reaction of a lazy value to a recursive request: fail with an {@link IllegalStateException}.
     * Subclasses may override this to tolerate recursion (see {@link #NO_LOCKS}).
     */
    @NotNull
    protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
        throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
    }

    /**
     * The outcome of handling a recursive request to a lazy value: either a value to return to the recursive caller,
     * or "fall through", which makes the lazy value go on and run its computation.
     */
    private static class RecursionDetectedResult<T> {

        @NotNull
        public static <T> RecursionDetectedResult<T> value(T value) {
            return new RecursionDetectedResult<T>(value, false);
        }

        @NotNull
        public static <T> RecursionDetectedResult<T> fallThrough() {
            return new RecursionDetectedResult<T>(null, true);
        }

        private final T value;
        private final boolean fallThrough;

        private RecursionDetectedResult(T value, boolean fallThrough) {
            this.value = value;
            this.fallThrough = fallThrough;
        }

        public T getValue() {
            assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
            return value;
        }

        public boolean isFallThrough() {
            return fallThrough;
        }

        @Override
        public String toString() {
            return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
        }
    }

    /**
     * Markers stored in place of a real value. Lazy values use all three of them,
     * memoized functions only use {@link #COMPUTING}.
     */
    private enum NotValue {
        NOT_COMPUTED,
        COMPUTING,
        RECURSION_WAS_DETECTED
    }

    /**
     * Lazy value that keeps either a {@link NotValue} marker, the computed value,
     * or an escaped throwable (see {@link WrappedValues#escapeThrowable}) in a single {@code volatile} field.
     */
    private class LockBasedLazyValue<T> implements NullableLazyValue<T> {

        private final Function0<? extends T> computable;

        @Nullable
        private volatile Object value = NotValue.NOT_COMPUTED;

        public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
            this.computable = computable;
        }

        @Override
        public boolean isComputed() {
            return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
        }

        @Override
        public boolean isComputing() {
            return value == NotValue.COMPUTING;
        }

        @Override
        public T invoke() {
            // Check without the lock first: anything that is not a marker is a final result
            Object _value = value;
            if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);

            lock.lock();
            try {
                // Re-read under the lock: the result may have been stored between the check above and lock()
                _value = value;
                if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);

                if (_value == NotValue.COMPUTING) {
                    // We are asked for the value while computing it: remember that and let the subclass decide
                    value = NotValue.RECURSION_WAS_DETECTED;
                    RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                if (_value == NotValue.RECURSION_WAS_DETECTED) {
                    RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                value = NotValue.COMPUTING;
                try {
                    T typedValue = computable.invoke();
                    // The value is stored before postCompute() runs, so postCompute() may safely request it
                    value = typedValue;
                    postCompute(typedValue);
                    return typedValue;
                }
                catch (Throwable throwable) {
                    if (value == NotValue.COMPUTING) {
                        // Store only if it's a genuine result, not something thrown through recursionDetected()
                        value = WrappedValues.escapeThrowable(throwable);
                    }
                    throw exceptionHandlingStrategy.handleException(throwable);
                }
            }
            finally {
                lock.unlock();
            }
        }

        /**
         * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
         * @return a value to be returned on a recursive call or subsequent calls
         */
        @NotNull
        protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
            return recursionDetectedDefault();
        }

        /**
         * Hook invoked with the computed value right after it has been stored. Does nothing by default.
         */
        protected void postCompute(T value) {
            // Doing something in post-compute helps prevent infinite recursion
        }
    }

    /**
     * Lazy value that additionally asserts the result is not {@code null}.
     */
    private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {

        public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
            super(computable);
        }

        @Override
        @NotNull
        public T invoke() {
            T result = super.invoke();
            assert result != null : "compute() returned null";
            return result;
        }
    }

    /**
     * Memoized function that caches, per input, either the {@link NotValue#COMPUTING} marker, the escaped result
     * (see {@link WrappedValues#escapeNull}) or the escaped throwable (see {@link WrappedValues#escapeThrowable}).
     */
    private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
        private final ConcurrentMap<K, Object> cache;
        private final Function1<? super K, ? extends V> compute;

        public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
            this.cache = map;
            this.compute = compute;
        }

        @Override
        @Nullable
        public V invoke(K input) {
            // Check without the lock first: anything but "absent" and COMPUTING is a final result
            Object value = cache.get(input);
            if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);

            lock.lock();
            try {
                // Re-read under the lock
                value = cache.get(input);
                if (value == NotValue.COMPUTING) {
                    throw recursionDetected(input);
                }
                if (value != null) return WrappedValues.unescapeExceptionOrNull(value);

                AssertionError error = null;
                try {
                    cache.put(input, NotValue.COMPUTING);
                    V typedValue = compute.invoke(input);
                    Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));

                    // This code effectively asserts that oldValue is the COMPUTING marker we have put above.
                    // The trickery is here because below we catch all exceptions thrown here, and this is the only
                    // exception that shouldn't be stored.
                    // A seemingly obvious way to handle this case would be to declare a special exception class,
                    // but one memoized function is likely to (indirectly) call another, and if this second one
                    // threw that exception, we could not tell the two apart.
                    if (oldValue != NotValue.COMPUTING) {
                        error = raceCondition(input, oldValue);
                        throw error;
                    }

                    return typedValue;
                }
                catch (Throwable throwable) {
                    // Our own race condition report: pass it on without storing it in the cache
                    if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);

                    Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
                    if (oldValue != NotValue.COMPUTING) {
                        throw raceCondition(input, oldValue);
                    }

                    throw exceptionHandlingStrategy.handleException(throwable);
                }
            }
            finally {
                lock.unlock();
            }
        }

        @NotNull
        private AssertionError recursionDetected(K input) {
            return sanitizeStackTrace(
                    new AssertionError("Recursion detected on input: " + input + " under " + LockBasedStorageManager.this)
            );
        }

        @NotNull
        private AssertionError raceCondition(K input, Object oldValue) {
            return sanitizeStackTrace(
                    new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
                                       " under " + LockBasedStorageManager.this)
            );
        }

        @Override
        public boolean isComputed(K key) {
            Object value = cache.get(key);
            return value != null && value != NotValue.COMPUTING;
        }
    }

    /**
     * Memoized function that additionally asserts the result is not {@code null}.
     */
    private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {

        public MapBasedMemoizedFunctionToNotNull(
                @NotNull ConcurrentMap<K, Object> map,
                @NotNull Function1<? super K, ? extends V> compute
        ) {
            super(map, compute);
        }

        @NotNull
        @Override
        public V invoke(K input) {
            V result = super.invoke(input);
            assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
            return result;
        }
    }

    /**
     * Creates a manager that shares the lock of {@code base} but handles exceptions with {@code newStrategy}.
     *
     * @param base        the manager whose lock is reused
     * @param newStrategy the exception handling strategy of the new manager
     * @return a new storage manager
     */
    @NotNull
    public static LockBasedStorageManager createDelegatingWithSameLock(
            @NotNull LockBasedStorageManager base,
            @NotNull ExceptionHandlingStrategy newStrategy
    ) {
        return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock);
    }

    /**
     * Drops the leading stack frames whose class name starts with this class's package name,
     * so that the trace begins at the first frame outside of the storage code.
     * <p>
     * If no such frame exists (for example, the stack trace is empty), the assertion below fails when assertions
     * are enabled; otherwise the throwable is returned with its stack trace untouched.
     *
     * @param throwable the throwable to adjust; it is modified in place
     * @return the same {@code throwable} instance
     */
    @NotNull
    private static <T extends Throwable> T sanitizeStackTrace(@NotNull T throwable) {
        String storagePackageName = getStoragePackageName();
        StackTraceElement[] stackTrace = throwable.getStackTrace();
        int size = stackTrace.length;

        int firstNonStorage = -1;
        for (int i = 0; i < size; i++) {
            // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage
            if (!stackTrace[i].getClassName().startsWith(storagePackageName)) {
                firstNonStorage = i;
                break;
            }
        }
        assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager";
        if (firstNonStorage < 0) {
            // Assertions are disabled: keep the original diagnostic instead of failing in subList(-1, size)
            return throwable;
        }

        List<StackTraceElement> list = Arrays.asList(stackTrace).subList(firstNonStorage, size);
        throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()]));
        return throwable;
    }

    /**
     * @return the package name of this class, or an empty string when the class name contains no dot
     */
    @NotNull
    private static String getStoragePackageName() {
        // Taken from the class name because Class.getPackage() is allowed to return null
        String className = LockBasedStorageManager.class.getName();
        int lastDot = className.lastIndexOf('.');
        return lastDot < 0 ? "" : className.substring(0, lastDot);
    }
}