001 /* 002 * Copyright 2010-2013 JetBrains s.r.o. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017 package org.jetbrains.jet.storage; 018 019 import kotlin.Function0; 020 import kotlin.Function1; 021 import kotlin.Unit; 022 import org.jetbrains.annotations.NotNull; 023 import org.jetbrains.annotations.Nullable; 024 import org.jetbrains.jet.utils.UtilsPackage; 025 import org.jetbrains.jet.utils.WrappedValues; 026 027 import java.util.concurrent.ConcurrentHashMap; 028 import java.util.concurrent.ConcurrentMap; 029 import java.util.concurrent.locks.Lock; 030 import java.util.concurrent.locks.ReentrantLock; 031 032 public class LockBasedStorageManager implements StorageManager { 033 034 public interface ExceptionHandlingStrategy { 035 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() { 036 @NotNull 037 @Override 038 public RuntimeException handleException(@NotNull Throwable throwable) { 039 throw UtilsPackage.rethrow(throwable); 040 } 041 }; 042 043 /* 044 * The signature of this method is a trick: it is used as 045 * 046 * throw strategy.handleException(...) 047 * 048 * most implementations of this method throw exceptions themselves, so it does not matter what they return 049 */ 050 @NotNull 051 RuntimeException handleException(@NotNull Throwable throwable); 052 } 053 054 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) { 055 @NotNull 056 @Override 057 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 058 return RecursionDetectedResult.fallThrough(); 059 } 060 }; 061 062 @NotNull 063 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 064 return new LockBasedStorageManager(exceptionHandlingStrategy); 065 } 066 067 protected final Lock lock; 068 private final ExceptionHandlingStrategy exceptionHandlingStrategy; 069 private final String debugText; 070 071 private LockBasedStorageManager( 072 @NotNull String debugText, 073 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy, 074 @NotNull Lock lock 075 ) { 076 this.lock = lock; 077 this.exceptionHandlingStrategy = exceptionHandlingStrategy; 078 this.debugText = debugText; 079 } 080 081 public LockBasedStorageManager() { 082 this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock()); 083 } 084 085 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) { 086 this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock()); 087 } 088 089 private static String getPointOfConstruction() { 090 StackTraceElement[] trace = Thread.currentThread().getStackTrace(); 091 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it 092 if (trace.length <= 3) return "<unknown creating class>"; 093 return trace[3].toString(); 094 } 095 096 @Override 097 public String toString() { 098 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")"; 099 } 100 101 @NotNull 102 @Override 103 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) { 104 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>()); 105 } 106 107 @NotNull 108 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction( 109 @NotNull Function1<? super K, ? extends V> compute, 110 @NotNull ConcurrentMap<K, Object> map 111 ) { 112 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute); 113 } 114 115 @NotNull 116 @Override 117 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) { 118 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>()); 119 } 120 121 @NotNull 122 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues( 123 @NotNull Function1<? super K, ? extends V> compute, 124 @NotNull ConcurrentMap<K, Object> map 125 ) { 126 return new MapBasedMemoizedFunction<K, V>(map, compute); 127 } 128 129 @NotNull 130 @Override 131 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) { 132 return new LockBasedNotNullLazyValue<T>(computable); 133 } 134 135 @NotNull 136 @Override 137 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue( 138 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall 139 ) { 140 return new LockBasedNotNullLazyValue<T>(computable) { 141 @NotNull 142 @Override 143 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 144 return RecursionDetectedResult.value(onRecursiveCall); 145 } 146 }; 147 } 148 149 @NotNull 150 @Override 151 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute( 152 @NotNull Function0<? extends T> computable, 153 final Function1<? super Boolean, ? extends T> onRecursiveCall, 154 @NotNull final Function1<? super T, ? extends Unit> postCompute 155 ) { 156 return new LockBasedNotNullLazyValue<T>(computable) { 157 @NotNull 158 @Override 159 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 160 if (onRecursiveCall == null) { 161 return super.recursionDetected(firstTime); 162 } 163 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime)); 164 } 165 166 @Override 167 protected void postCompute(@NotNull T value) { 168 postCompute.invoke(value); 169 } 170 }; 171 } 172 173 @NotNull 174 @Override 175 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) { 176 return new LockBasedLazyValue<T>(computable); 177 } 178 179 @NotNull 180 @Override 181 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) { 182 return new LockBasedLazyValue<T>(computable) { 183 @NotNull 184 @Override 185 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 186 return RecursionDetectedResult.value(onRecursiveCall); 187 } 188 }; 189 } 190 191 @NotNull 192 @Override 193 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute( 194 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute 195 ) { 196 return new LockBasedLazyValue<T>(computable) { 197 @Override 198 protected void postCompute(@Nullable T value) { 199 postCompute.invoke(value); 200 } 201 }; 202 } 203 204 @Override 205 public <T> T compute(@NotNull Function0<? extends T> computable) { 206 lock.lock(); 207 try { 208 return computable.invoke(); 209 } 210 finally { 211 lock.unlock(); 212 } 213 } 214 215 @NotNull 216 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 217 throw new IllegalStateException("Recursive call in a lazy value under " + this); 218 } 219 220 private static class RecursionDetectedResult<T> { 221 222 @NotNull 223 public static <T> RecursionDetectedResult<T> value(T value) { 224 return new RecursionDetectedResult<T>(value, false); 225 } 226 227 @NotNull 228 public static <T> RecursionDetectedResult<T> fallThrough() { 229 return new RecursionDetectedResult<T>(null, true); 230 } 231 232 private final T value; 233 private final boolean fallThrough; 234 235 private RecursionDetectedResult(T value, boolean fallThrough) { 236 this.value = value; 237 this.fallThrough = fallThrough; 238 } 239 240 public T getValue() { 241 assert !fallThrough : "A value requested from FALL_THROUGH in " + this; 242 return value; 243 } 244 245 public boolean isFallThrough() { 246 return fallThrough; 247 } 248 249 @Override 250 public String toString() { 251 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value); 252 } 253 } 254 255 private enum NotValue { 256 NOT_COMPUTED, 257 COMPUTING, 258 RECURSION_WAS_DETECTED 259 } 260 261 private class LockBasedLazyValue<T> implements NullableLazyValue<T> { 262 263 private final Function0<? extends T> computable; 264 265 @Nullable 266 private volatile Object value = NotValue.NOT_COMPUTED; 267 268 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) { 269 this.computable = computable; 270 } 271 272 @Override 273 public boolean isComputed() { 274 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING; 275 } 276 277 @Override 278 public T invoke() { 279 Object _value = value; 280 if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 281 282 lock.lock(); 283 try { 284 _value = value; 285 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 286 287 if (_value == NotValue.COMPUTING) { 288 value = NotValue.RECURSION_WAS_DETECTED; 289 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true); 290 if (!result.isFallThrough()) { 291 return result.getValue(); 292 } 293 } 294 295 if (_value == NotValue.RECURSION_WAS_DETECTED) { 296 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false); 297 if (!result.isFallThrough()) { 298 return result.getValue(); 299 } 300 } 301 302 value = NotValue.COMPUTING; 303 try { 304 T typedValue = computable.invoke(); 305 value = typedValue; 306 postCompute(typedValue); 307 return typedValue; 308 } 309 catch (Throwable throwable) { 310 if (value == NotValue.COMPUTING) { 311 // Store only if it's a genuine result, not something thrown through recursionDetected() 312 value = WrappedValues.escapeThrowable(throwable); 313 } 314 throw exceptionHandlingStrategy.handleException(throwable); 315 } 316 } 317 finally { 318 lock.unlock(); 319 } 320 } 321 322 /** 323 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise 324 * @return a value to be returned on a recursive call or subsequent calls 325 */ 326 @NotNull 327 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 328 return recursionDetectedDefault(); 329 } 330 331 protected void postCompute(T value) { 332 // Doing something in post-compute helps prevent infinite recursion 333 } 334 } 335 336 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> { 337 338 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) { 339 super(computable); 340 } 341 342 @Override 343 @NotNull 344 public T invoke() { 345 T result = super.invoke(); 346 assert result != null : "compute() returned null"; 347 return result; 348 } 349 } 350 351 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> { 352 private final ConcurrentMap<K, Object> cache; 353 private final Function1<? super K, ? extends V> compute; 354 355 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) { 356 this.cache = map; 357 this.compute = compute; 358 } 359 360 @Override 361 @Nullable 362 public V invoke(K input) { 363 Object value = cache.get(input); 364 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value); 365 366 lock.lock(); 367 try { 368 value = cache.get(input); 369 assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this; 370 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 371 372 AssertionError error = null; 373 try { 374 cache.put(input, NotValue.COMPUTING); 375 V typedValue = compute.invoke(input); 376 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue)); 377 378 // This code effectively asserts that oldValue is null 379 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored 380 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that 381 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed 382 if (oldValue != NotValue.COMPUTING) { 383 error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue + 384 " under " + LockBasedStorageManager.this); 385 throw error; 386 } 387 388 return typedValue; 389 } 390 catch (Throwable throwable) { 391 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable); 392 393 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable)); 394 assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue + 395 " under " + LockBasedStorageManager.this; 396 397 throw exceptionHandlingStrategy.handleException(throwable); 398 } 399 } 400 finally { 401 lock.unlock(); 402 } 403 } 404 } 405 406 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> { 407 408 public MapBasedMemoizedFunctionToNotNull( 409 @NotNull ConcurrentMap<K, Object> map, 410 @NotNull Function1<? super K, ? extends V> compute 411 ) { 412 super(map, compute); 413 } 414 415 @NotNull 416 @Override 417 public V invoke(K input) { 418 V result = super.invoke(input); 419 assert result != null : "compute() returned null under " + LockBasedStorageManager.this; 420 return result; 421 } 422 } 423 424 }