001    /*
002     * Copyright 2010-2015 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.kotlin.storage;
018    
019    import kotlin.Function0;
020    import kotlin.Function1;
021    import kotlin.Unit;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.annotations.Nullable;
024    import org.jetbrains.kotlin.utils.UtilsPackage;
025    import org.jetbrains.kotlin.utils.WrappedValues;
026    
027    import java.util.concurrent.ConcurrentHashMap;
028    import java.util.concurrent.ConcurrentMap;
029    import java.util.concurrent.locks.Lock;
030    import java.util.concurrent.locks.ReentrantLock;
031    
032    public class LockBasedStorageManager implements StorageManager {
033        public interface ExceptionHandlingStrategy {
034            ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
035                @NotNull
036                @Override
037                public RuntimeException handleException(@NotNull Throwable throwable) {
038                    throw UtilsPackage.rethrow(throwable);
039                }
040            };
041    
042            /*
043             * The signature of this method is a trick: it is used as
044             *
045             *     throw strategy.handleException(...)
046             *
047             * most implementations of this method throw exceptions themselves, so it does not matter what they return
048             */
049            @NotNull
050            RuntimeException handleException(@NotNull Throwable throwable);
051        }
052    
053        public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
054            @NotNull
055            @Override
056            protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
057                return RecursionDetectedResult.fallThrough();
058            }
059        };
060    
061        @NotNull
062        public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
063            return new LockBasedStorageManager(exceptionHandlingStrategy);
064        }
065    
066        protected final Lock lock;
067        private final ExceptionHandlingStrategy exceptionHandlingStrategy;
068        private final String debugText;
069    
070        private LockBasedStorageManager(
071                @NotNull String debugText,
072                @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
073                @NotNull Lock lock
074        ) {
075            this.lock = lock;
076            this.exceptionHandlingStrategy = exceptionHandlingStrategy;
077            this.debugText = debugText;
078        }
079    
080        public LockBasedStorageManager() {
081            this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
082        }
083    
084        protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
085            this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock());
086        }
087    
088        private static String getPointOfConstruction() {
089            StackTraceElement[] trace = Thread.currentThread().getStackTrace();
090            // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
091            if (trace.length <= 3) return "<unknown creating class>";
092            return trace[3].toString();
093        }
094    
095        @Override
096        public String toString() {
097            return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
098        }
099    
100        @NotNull
101        @Override
102        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
103            return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
104        }
105    
106        @NotNull
107        @Override
108        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
109                @NotNull Function1<? super K, ? extends V> compute,
110                @NotNull ConcurrentMap<K, Object> map
111        ) {
112            return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
113        }
114    
115        @NotNull
116        @Override
117        public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
118            return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
119        }
120    
121        @Override
122        @NotNull
123        public  <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
124                @NotNull Function1<? super K, ? extends V> compute,
125                @NotNull ConcurrentMap<K, Object> map
126        ) {
127            return new MapBasedMemoizedFunction<K, V>(map, compute);
128        }
129    
130        @NotNull
131        @Override
132        public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
133            return new LockBasedNotNullLazyValue<T>(computable);
134        }
135    
136        @NotNull
137        @Override
138        public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
139                @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
140        ) {
141            return new LockBasedNotNullLazyValue<T>(computable) {
142                @NotNull
143                @Override
144                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
145                    return RecursionDetectedResult.value(onRecursiveCall);
146                }
147            };
148        }
149    
150        @NotNull
151        @Override
152        public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
153                @NotNull Function0<? extends T> computable,
154                final Function1<? super Boolean, ? extends T> onRecursiveCall,
155                @NotNull final Function1<? super T, ? extends Unit> postCompute
156        ) {
157            return new LockBasedNotNullLazyValue<T>(computable) {
158                @NotNull
159                @Override
160                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
161                    if (onRecursiveCall == null) {
162                        return super.recursionDetected(firstTime);
163                    }
164                    return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
165                }
166    
167                @Override
168                protected void postCompute(@NotNull T value) {
169                    postCompute.invoke(value);
170                }
171            };
172        }
173    
174        @NotNull
175        @Override
176        public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
177            return new LockBasedLazyValue<T>(computable);
178        }
179    
180        @NotNull
181        @Override
182        public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
183            return new LockBasedLazyValue<T>(computable) {
184                @NotNull
185                @Override
186                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
187                    return RecursionDetectedResult.value(onRecursiveCall);
188                }
189            };
190        }
191    
192        @NotNull
193        @Override
194        public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
195                @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute
196        ) {
197            return new LockBasedLazyValue<T>(computable) {
198                @Override
199                protected void postCompute(@Nullable T value) {
200                    postCompute.invoke(value);
201                }
202            };
203        }
204    
205        @Override
206        public <T> T compute(@NotNull Function0<? extends T> computable) {
207            lock.lock();
208            try {
209                return computable.invoke();
210            }
211            catch (Throwable throwable) {
212                throw exceptionHandlingStrategy.handleException(throwable);
213            }
214            finally {
215                lock.unlock();
216            }
217        }
218    
219        @NotNull
220        private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() {
221            // memory optimization: fewer segments and entries stored
222            return new ConcurrentHashMap<K, Object>(3, 1, 2);
223        }
224    
225        @NotNull
226        protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
227            throw new IllegalStateException("Recursive call in a lazy value under " + this);
228        }
229    
230        private static class RecursionDetectedResult<T> {
231    
232            @NotNull
233            public static <T> RecursionDetectedResult<T> value(T value) {
234                return new RecursionDetectedResult<T>(value, false);
235            }
236    
237            @NotNull
238            public static <T> RecursionDetectedResult<T> fallThrough() {
239                return new RecursionDetectedResult<T>(null, true);
240            }
241    
242            private final T value;
243            private final boolean fallThrough;
244    
245            private RecursionDetectedResult(T value, boolean fallThrough) {
246                this.value = value;
247                this.fallThrough = fallThrough;
248            }
249    
250            public T getValue() {
251                assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
252                return value;
253            }
254    
255            public boolean isFallThrough() {
256                return fallThrough;
257            }
258    
259            @Override
260            public String toString() {
261                return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
262            }
263        }
264    
265        private enum NotValue {
266            NOT_COMPUTED,
267            COMPUTING,
268            RECURSION_WAS_DETECTED
269        }
270    
271        private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
272    
273            private final Function0<? extends T> computable;
274    
275            @Nullable
276            private volatile Object value = NotValue.NOT_COMPUTED;
277    
278            public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
279                this.computable = computable;
280            }
281    
282            @Override
283            public boolean isComputed() {
284                return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
285            }
286    
287            @Override
288            public T invoke() {
289                Object _value = value;
290                if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
291    
292                lock.lock();
293                try {
294                    _value = value;
295                    if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
296    
297                    if (_value == NotValue.COMPUTING) {
298                        value = NotValue.RECURSION_WAS_DETECTED;
299                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
300                        if (!result.isFallThrough()) {
301                            return result.getValue();
302                        }
303                    }
304    
305                    if (_value == NotValue.RECURSION_WAS_DETECTED) {
306                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
307                        if (!result.isFallThrough()) {
308                            return result.getValue();
309                        }
310                    }
311    
312                    value = NotValue.COMPUTING;
313                    try {
314                        T typedValue = computable.invoke();
315                        value = typedValue;
316                        postCompute(typedValue);
317                        return typedValue;
318                    }
319                    catch (Throwable throwable) {
320                        if (value == NotValue.COMPUTING) {
321                            // Store only if it's a genuine result, not something thrown through recursionDetected()
322                            value = WrappedValues.escapeThrowable(throwable);
323                        }
324                        throw exceptionHandlingStrategy.handleException(throwable);
325                    }
326                }
327                finally {
328                    lock.unlock();
329                }
330            }
331    
332            /**
333             * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
334             * @return a value to be returned on a recursive call or subsequent calls
335             */
336            @NotNull
337            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
338                return recursionDetectedDefault();
339            }
340    
341            protected void postCompute(T value) {
342                // Doing something in post-compute helps prevent infinite recursion
343            }
344        }
345    
346        private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
347    
348            public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
349                super(computable);
350            }
351    
352            @Override
353            @NotNull
354            public T invoke() {
355                T result = super.invoke();
356                assert result != null : "compute() returned null";
357                return result;
358            }
359        }
360    
361        private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
362            private final ConcurrentMap<K, Object> cache;
363            private final Function1<? super K, ? extends V> compute;
364    
365            public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
366                this.cache = map;
367                this.compute = compute;
368            }
369    
370            @Override
371            @Nullable
372            public V invoke(K input) {
373                Object value = cache.get(input);
374                if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
375    
376                lock.lock();
377                try {
378                    value = cache.get(input);
379                    assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this;
380                    if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
381    
382                    AssertionError error = null;
383                    try {
384                        cache.put(input, NotValue.COMPUTING);
385                        V typedValue = compute.invoke(input);
386                        Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
387    
388                        // This code effectively asserts that oldValue is null
389                        // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
390                        // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
391                        // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
392                        if (oldValue != NotValue.COMPUTING) {
393                            error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
394                                                       " under " + LockBasedStorageManager.this);
395                            throw error;
396                        }
397    
398                        return typedValue;
399                    }
400                    catch (Throwable throwable) {
401                        if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);
402    
403                        Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
404                        assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue +
405                                                                " under " + LockBasedStorageManager.this;
406    
407                        throw exceptionHandlingStrategy.handleException(throwable);
408                    }
409                }
410                finally {
411                    lock.unlock();
412                }
413            }
414        }
415    
416        private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
417    
418            public MapBasedMemoizedFunctionToNotNull(
419                    @NotNull ConcurrentMap<K, Object> map,
420                    @NotNull Function1<? super K, ? extends V> compute
421            ) {
422                super(map, compute);
423            }
424    
425            @NotNull
426            @Override
427            public V invoke(K input) {
428                V result = super.invoke(input);
429                assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
430                return result;
431            }
432        }
433    
434        @NotNull
435        public static LockBasedStorageManager createDelegatingWithSameLock(
436                @NotNull LockBasedStorageManager base,
437                @NotNull ExceptionHandlingStrategy newStrategy
438        ) {
439            return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock);
440        }
441    }