001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.storage;
018    
019    import kotlin.Function0;
020    import kotlin.Function1;
021    import kotlin.Unit;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.annotations.Nullable;
024    import org.jetbrains.jet.utils.UtilsPackage;
025    import org.jetbrains.jet.utils.WrappedValues;
026    
027    import java.util.concurrent.ConcurrentHashMap;
028    import java.util.concurrent.ConcurrentMap;
029    import java.util.concurrent.locks.Lock;
030    import java.util.concurrent.locks.ReentrantLock;
031    
032    public class LockBasedStorageManager implements StorageManager {
033    
034        public interface ExceptionHandlingStrategy {
035            ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
036                @NotNull
037                @Override
038                public RuntimeException handleException(@NotNull Throwable throwable) {
039                    throw UtilsPackage.rethrow(throwable);
040                }
041            };
042    
043            /*
044             * The signature of this method is a trick: it is used as
045             *
046             *     throw strategy.handleException(...)
047             *
048             * most implementations of this method throw exceptions themselves, so it does not matter what they return
049             */
050            @NotNull
051            RuntimeException handleException(@NotNull Throwable throwable);
052        }
053    
054        public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
055            @NotNull
056            @Override
057            protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
058                return RecursionDetectedResult.fallThrough();
059            }
060        };
061    
062        @NotNull
063        public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
064            return new LockBasedStorageManager(exceptionHandlingStrategy);
065        }
066    
067        protected final Lock lock;
068        private final ExceptionHandlingStrategy exceptionHandlingStrategy;
069        private final String debugText;
070    
071        private LockBasedStorageManager(
072                @NotNull String debugText,
073                @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
074                @NotNull Lock lock
075        ) {
076            this.lock = lock;
077            this.exceptionHandlingStrategy = exceptionHandlingStrategy;
078            this.debugText = debugText;
079        }
080    
081        public LockBasedStorageManager() {
082            this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
083        }
084    
085        protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
086            this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock());
087        }
088    
089        private static String getPointOfConstruction() {
090            StackTraceElement[] trace = Thread.currentThread().getStackTrace();
091            // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
092            if (trace.length <= 3) return "<unknown creating class>";
093            return trace[3].toString();
094        }
095    
096        @Override
097        public String toString() {
098            return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
099        }
100    
101        @NotNull
102        @Override
103        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
104            return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
105        }
106    
107        @NotNull
108        protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
109                @NotNull Function1<? super K, ? extends V> compute,
110                @NotNull ConcurrentMap<K, Object> map
111        ) {
112            return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
113        }
114    
115        @NotNull
116        @Override
117        public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
118            return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
119        }
120    
121        @NotNull
122        protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
123                @NotNull Function1<? super K, ? extends V> compute,
124                @NotNull ConcurrentMap<K, Object> map
125        ) {
126            return new MapBasedMemoizedFunction<K, V>(map, compute);
127        }
128    
129        @NotNull
130        @Override
131        public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
132            return new LockBasedNotNullLazyValue<T>(computable);
133        }
134    
135        @NotNull
136        @Override
137        public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
138                @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
139        ) {
140            return new LockBasedNotNullLazyValue<T>(computable) {
141                @NotNull
142                @Override
143                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
144                    return RecursionDetectedResult.value(onRecursiveCall);
145                }
146            };
147        }
148    
149        @NotNull
150        @Override
151        public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
152                @NotNull Function0<? extends T> computable,
153                final Function1<? super Boolean, ? extends T> onRecursiveCall,
154                @NotNull final Function1<? super T, ? extends Unit> postCompute
155        ) {
156            return new LockBasedNotNullLazyValue<T>(computable) {
157                @NotNull
158                @Override
159                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
160                    if (onRecursiveCall == null) {
161                        return super.recursionDetected(firstTime);
162                    }
163                    return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
164                }
165    
166                @Override
167                protected void postCompute(@NotNull T value) {
168                    postCompute.invoke(value);
169                }
170            };
171        }
172    
173        @NotNull
174        @Override
175        public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
176            return new LockBasedLazyValue<T>(computable);
177        }
178    
179        @NotNull
180        @Override
181        public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
182            return new LockBasedLazyValue<T>(computable) {
183                @NotNull
184                @Override
185                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
186                    return RecursionDetectedResult.value(onRecursiveCall);
187                }
188            };
189        }
190    
191        @NotNull
192        @Override
193        public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
194                @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute
195        ) {
196            return new LockBasedLazyValue<T>(computable) {
197                @Override
198                protected void postCompute(@Nullable T value) {
199                    postCompute.invoke(value);
200                }
201            };
202        }
203    
204        @Override
205        public <T> T compute(@NotNull Function0<? extends T> computable) {
206            lock.lock();
207            try {
208                return computable.invoke();
209            }
210            catch (Throwable throwable) {
211                throw exceptionHandlingStrategy.handleException(throwable);
212            }
213            finally {
214                lock.unlock();
215            }
216        }
217    
218        @NotNull
219        protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
220            throw new IllegalStateException("Recursive call in a lazy value under " + this);
221        }
222    
223        private static class RecursionDetectedResult<T> {
224    
225            @NotNull
226            public static <T> RecursionDetectedResult<T> value(T value) {
227                return new RecursionDetectedResult<T>(value, false);
228            }
229    
230            @NotNull
231            public static <T> RecursionDetectedResult<T> fallThrough() {
232                return new RecursionDetectedResult<T>(null, true);
233            }
234    
235            private final T value;
236            private final boolean fallThrough;
237    
238            private RecursionDetectedResult(T value, boolean fallThrough) {
239                this.value = value;
240                this.fallThrough = fallThrough;
241            }
242    
243            public T getValue() {
244                assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
245                return value;
246            }
247    
248            public boolean isFallThrough() {
249                return fallThrough;
250            }
251    
252            @Override
253            public String toString() {
254                return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
255            }
256        }
257    
258        private enum NotValue {
259            NOT_COMPUTED,
260            COMPUTING,
261            RECURSION_WAS_DETECTED
262        }
263    
264        private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
265    
266            private final Function0<? extends T> computable;
267    
268            @Nullable
269            private volatile Object value = NotValue.NOT_COMPUTED;
270    
271            public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
272                this.computable = computable;
273            }
274    
275            @Override
276            public boolean isComputed() {
277                return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
278            }
279    
280            @Override
281            public T invoke() {
282                Object _value = value;
283                if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
284    
285                lock.lock();
286                try {
287                    _value = value;
288                    if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
289    
290                    if (_value == NotValue.COMPUTING) {
291                        value = NotValue.RECURSION_WAS_DETECTED;
292                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
293                        if (!result.isFallThrough()) {
294                            return result.getValue();
295                        }
296                    }
297    
298                    if (_value == NotValue.RECURSION_WAS_DETECTED) {
299                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
300                        if (!result.isFallThrough()) {
301                            return result.getValue();
302                        }
303                    }
304    
305                    value = NotValue.COMPUTING;
306                    try {
307                        T typedValue = computable.invoke();
308                        value = typedValue;
309                        postCompute(typedValue);
310                        return typedValue;
311                    }
312                    catch (Throwable throwable) {
313                        if (value == NotValue.COMPUTING) {
314                            // Store only if it's a genuine result, not something thrown through recursionDetected()
315                            value = WrappedValues.escapeThrowable(throwable);
316                        }
317                        throw exceptionHandlingStrategy.handleException(throwable);
318                    }
319                }
320                finally {
321                    lock.unlock();
322                }
323            }
324    
325            /**
326             * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
327             * @return a value to be returned on a recursive call or subsequent calls
328             */
329            @NotNull
330            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
331                return recursionDetectedDefault();
332            }
333    
334            protected void postCompute(T value) {
335                // Doing something in post-compute helps prevent infinite recursion
336            }
337        }
338    
339        private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
340    
341            public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
342                super(computable);
343            }
344    
345            @Override
346            @NotNull
347            public T invoke() {
348                T result = super.invoke();
349                assert result != null : "compute() returned null";
350                return result;
351            }
352        }
353    
354        private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
355            private final ConcurrentMap<K, Object> cache;
356            private final Function1<? super K, ? extends V> compute;
357    
358            public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
359                this.cache = map;
360                this.compute = compute;
361            }
362    
363            @Override
364            @Nullable
365            public V invoke(K input) {
366                Object value = cache.get(input);
367                if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
368    
369                lock.lock();
370                try {
371                    value = cache.get(input);
372                    assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this;
373                    if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
374    
375                    AssertionError error = null;
376                    try {
377                        cache.put(input, NotValue.COMPUTING);
378                        V typedValue = compute.invoke(input);
379                        Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
380    
381                        // This code effectively asserts that oldValue is null
382                        // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
383                        // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
384                        // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
385                        if (oldValue != NotValue.COMPUTING) {
386                            error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
387                                                       " under " + LockBasedStorageManager.this);
388                            throw error;
389                        }
390    
391                        return typedValue;
392                    }
393                    catch (Throwable throwable) {
394                        if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);
395    
396                        Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
397                        assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue +
398                                                                " under " + LockBasedStorageManager.this;
399    
400                        throw exceptionHandlingStrategy.handleException(throwable);
401                    }
402                }
403                finally {
404                    lock.unlock();
405                }
406            }
407        }
408    
409        private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
410    
411            public MapBasedMemoizedFunctionToNotNull(
412                    @NotNull ConcurrentMap<K, Object> map,
413                    @NotNull Function1<? super K, ? extends V> compute
414            ) {
415                super(map, compute);
416            }
417    
418            @NotNull
419            @Override
420            public V invoke(K input) {
421                V result = super.invoke(input);
422                assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
423                return result;
424            }
425        }
426    
427        @NotNull
428        public static LockBasedStorageManager createDelegatingWithSameLock(
429                @NotNull LockBasedStorageManager base,
430                @NotNull ExceptionHandlingStrategy newStrategy
431        ) {
432            return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock);
433        }
434    }