001 /*
002 * Copyright 2010-2015 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.kotlin.storage;
018
019 import kotlin.Unit;
020 import kotlin.jvm.functions.Function0;
021 import kotlin.jvm.functions.Function1;
022 import org.jetbrains.annotations.NotNull;
023 import org.jetbrains.annotations.Nullable;
024 import org.jetbrains.kotlin.utils.UtilsPackage;
025 import org.jetbrains.kotlin.utils.WrappedValues;
026
027 import java.util.Arrays;
028 import java.util.List;
029 import java.util.concurrent.ConcurrentHashMap;
030 import java.util.concurrent.ConcurrentMap;
031 import java.util.concurrent.locks.Lock;
032 import java.util.concurrent.locks.ReentrantLock;
033
034 public class LockBasedStorageManager implements StorageManager {
035 public interface ExceptionHandlingStrategy {
036 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
037 @NotNull
038 @Override
039 public RuntimeException handleException(@NotNull Throwable throwable) {
040 throw UtilsPackage.rethrow(throwable);
041 }
042 };
043
044 /*
045 * The signature of this method is a trick: it is used as
046 *
047 * throw strategy.handleException(...)
048 *
049 * most implementations of this method throw exceptions themselves, so it does not matter what they return
050 */
051 @NotNull
052 RuntimeException handleException(@NotNull Throwable throwable);
053 }
054
055 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
056 @NotNull
057 @Override
058 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
059 return RecursionDetectedResult.fallThrough();
060 }
061 };
062
063 @NotNull
064 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
065 return new LockBasedStorageManager(exceptionHandlingStrategy);
066 }
067
068 protected final Lock lock;
069 private final ExceptionHandlingStrategy exceptionHandlingStrategy;
070 private final String debugText;
071
072 private LockBasedStorageManager(
073 @NotNull String debugText,
074 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
075 @NotNull Lock lock
076 ) {
077 this.lock = lock;
078 this.exceptionHandlingStrategy = exceptionHandlingStrategy;
079 this.debugText = debugText;
080 }
081
082 public LockBasedStorageManager() {
083 this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
084 }
085
086 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
087 this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock());
088 }
089
090 private static String getPointOfConstruction() {
091 StackTraceElement[] trace = Thread.currentThread().getStackTrace();
092 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
093 if (trace.length <= 3) return "<unknown creating class>";
094 return trace[3].toString();
095 }
096
097 @Override
098 public String toString() {
099 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
100 }
101
102 @NotNull
103 @Override
104 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
105 return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
106 }
107
108 @NotNull
109 @Override
110 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
111 @NotNull Function1<? super K, ? extends V> compute,
112 @NotNull ConcurrentMap<K, Object> map
113 ) {
114 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
115 }
116
117 @NotNull
118 @Override
119 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
120 return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
121 }
122
123 @Override
124 @NotNull
125 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
126 @NotNull Function1<? super K, ? extends V> compute,
127 @NotNull ConcurrentMap<K, Object> map
128 ) {
129 return new MapBasedMemoizedFunction<K, V>(map, compute);
130 }
131
132 @NotNull
133 @Override
134 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
135 return new LockBasedNotNullLazyValue<T>(computable);
136 }
137
138 @NotNull
139 @Override
140 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
141 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
142 ) {
143 return new LockBasedNotNullLazyValue<T>(computable) {
144 @NotNull
145 @Override
146 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
147 return RecursionDetectedResult.value(onRecursiveCall);
148 }
149 };
150 }
151
152 @NotNull
153 @Override
154 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
155 @NotNull Function0<? extends T> computable,
156 final Function1<? super Boolean, ? extends T> onRecursiveCall,
157 @NotNull final Function1<? super T, ? extends Unit> postCompute
158 ) {
159 return new LockBasedNotNullLazyValue<T>(computable) {
160 @NotNull
161 @Override
162 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
163 if (onRecursiveCall == null) {
164 return super.recursionDetected(firstTime);
165 }
166 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
167 }
168
169 @Override
170 protected void postCompute(@NotNull T value) {
171 postCompute.invoke(value);
172 }
173 };
174 }
175
176 @NotNull
177 @Override
178 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
179 return new LockBasedLazyValue<T>(computable);
180 }
181
182 @NotNull
183 @Override
184 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
185 return new LockBasedLazyValue<T>(computable) {
186 @NotNull
187 @Override
188 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
189 return RecursionDetectedResult.value(onRecursiveCall);
190 }
191 };
192 }
193
194 @NotNull
195 @Override
196 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
197 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute
198 ) {
199 return new LockBasedLazyValue<T>(computable) {
200 @Override
201 protected void postCompute(@Nullable T value) {
202 postCompute.invoke(value);
203 }
204 };
205 }
206
207 @Override
208 public <T> T compute(@NotNull Function0<? extends T> computable) {
209 lock.lock();
210 try {
211 return computable.invoke();
212 }
213 catch (Throwable throwable) {
214 throw exceptionHandlingStrategy.handleException(throwable);
215 }
216 finally {
217 lock.unlock();
218 }
219 }
220
221 @NotNull
222 private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() {
223 // memory optimization: fewer segments and entries stored
224 return new ConcurrentHashMap<K, Object>(3, 1, 2);
225 }
226
227 @NotNull
228 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
229 throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
230 }
231
232 private static class RecursionDetectedResult<T> {
233
234 @NotNull
235 public static <T> RecursionDetectedResult<T> value(T value) {
236 return new RecursionDetectedResult<T>(value, false);
237 }
238
239 @NotNull
240 public static <T> RecursionDetectedResult<T> fallThrough() {
241 return new RecursionDetectedResult<T>(null, true);
242 }
243
244 private final T value;
245 private final boolean fallThrough;
246
247 private RecursionDetectedResult(T value, boolean fallThrough) {
248 this.value = value;
249 this.fallThrough = fallThrough;
250 }
251
252 public T getValue() {
253 assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
254 return value;
255 }
256
257 public boolean isFallThrough() {
258 return fallThrough;
259 }
260
261 @Override
262 public String toString() {
263 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
264 }
265 }
266
267 private enum NotValue {
268 NOT_COMPUTED,
269 COMPUTING,
270 RECURSION_WAS_DETECTED
271 }
272
273 private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
274
275 private final Function0<? extends T> computable;
276
277 @Nullable
278 private volatile Object value = NotValue.NOT_COMPUTED;
279
280 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
281 this.computable = computable;
282 }
283
284 @Override
285 public boolean isComputed() {
286 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
287 }
288
289 @Override
290 public T invoke() {
291 Object _value = value;
292 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
293
294 lock.lock();
295 try {
296 _value = value;
297 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
298
299 if (_value == NotValue.COMPUTING) {
300 value = NotValue.RECURSION_WAS_DETECTED;
301 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
302 if (!result.isFallThrough()) {
303 return result.getValue();
304 }
305 }
306
307 if (_value == NotValue.RECURSION_WAS_DETECTED) {
308 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
309 if (!result.isFallThrough()) {
310 return result.getValue();
311 }
312 }
313
314 value = NotValue.COMPUTING;
315 try {
316 T typedValue = computable.invoke();
317 value = typedValue;
318 postCompute(typedValue);
319 return typedValue;
320 }
321 catch (Throwable throwable) {
322 if (value == NotValue.COMPUTING) {
323 // Store only if it's a genuine result, not something thrown through recursionDetected()
324 value = WrappedValues.escapeThrowable(throwable);
325 }
326 throw exceptionHandlingStrategy.handleException(throwable);
327 }
328 }
329 finally {
330 lock.unlock();
331 }
332 }
333
334 /**
335 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
336 * @return a value to be returned on a recursive call or subsequent calls
337 */
338 @NotNull
339 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
340 return recursionDetectedDefault();
341 }
342
343 protected void postCompute(T value) {
344 // Doing something in post-compute helps prevent infinite recursion
345 }
346 }
347
348 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
349
350 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
351 super(computable);
352 }
353
354 @Override
355 @NotNull
356 public T invoke() {
357 T result = super.invoke();
358 assert result != null : "compute() returned null";
359 return result;
360 }
361 }
362
363 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
364 private final ConcurrentMap<K, Object> cache;
365 private final Function1<? super K, ? extends V> compute;
366
367 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
368 this.cache = map;
369 this.compute = compute;
370 }
371
372 @Override
373 @Nullable
374 public V invoke(K input) {
375 Object value = cache.get(input);
376 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
377
378 lock.lock();
379 try {
380 value = cache.get(input);
381 if (value == NotValue.COMPUTING) {
382 throw recursionDetected(input);
383 }
384 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
385
386 AssertionError error = null;
387 try {
388 cache.put(input, NotValue.COMPUTING);
389 V typedValue = compute.invoke(input);
390 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
391
392 // This code effectively asserts that oldValue is null
393 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
394 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
395 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
396 if (oldValue != NotValue.COMPUTING) {
397 error = raceCondition(input, oldValue);
398 throw error;
399 }
400
401 return typedValue;
402 }
403 catch (Throwable throwable) {
404 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);
405
406 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
407 if (oldValue != NotValue.COMPUTING) {
408 throw raceCondition(input, oldValue);
409 }
410
411 throw exceptionHandlingStrategy.handleException(throwable);
412 }
413 }
414 finally {
415 lock.unlock();
416 }
417 }
418
419 @NotNull
420 private AssertionError recursionDetected(K input) {
421 return sanitizeStackTrace(
422 new AssertionError("Recursion detected on input: " + input + " under " + LockBasedStorageManager.this)
423 );
424 }
425
426 @NotNull
427 private AssertionError raceCondition(K input, Object oldValue) {
428 return sanitizeStackTrace(
429 new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
430 " under " + LockBasedStorageManager.this)
431 );
432 }
433
434 @Override
435 public boolean isComputed(K key) {
436 Object value = cache.get(key);
437 return value != null && value != NotValue.COMPUTING;
438 }
439 }
440
441 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
442
443 public MapBasedMemoizedFunctionToNotNull(
444 @NotNull ConcurrentMap<K, Object> map,
445 @NotNull Function1<? super K, ? extends V> compute
446 ) {
447 super(map, compute);
448 }
449
450 @NotNull
451 @Override
452 public V invoke(K input) {
453 V result = super.invoke(input);
454 assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
455 return result;
456 }
457 }
458
459 @NotNull
460 public static LockBasedStorageManager createDelegatingWithSameLock(
461 @NotNull LockBasedStorageManager base,
462 @NotNull ExceptionHandlingStrategy newStrategy
463 ) {
464 return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock);
465 }
466
467 @NotNull
468 private static <T extends Throwable> T sanitizeStackTrace(@NotNull T throwable) {
469 String storagePackageName = LockBasedStorageManager.class.getPackage().getName();
470 StackTraceElement[] stackTrace = throwable.getStackTrace();
471 int size = stackTrace.length;
472
473 int firstNonStorage = -1;
474 for (int i = 0; i < size; i++) {
475 // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage
476 if (!stackTrace[i].getClassName().startsWith(storagePackageName)) {
477 firstNonStorage = i;
478 break;
479 }
480 }
481 assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager";
482
483 List<StackTraceElement> list = Arrays.asList(stackTrace).subList(firstNonStorage, size);
484 throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()]));
485 return throwable;
486 }
487 }