001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.storage;
018    
019    import jet.Function0;
020    import jet.Function1;
021    import jet.Unit;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.annotations.Nullable;
024    import org.jetbrains.jet.utils.ExceptionUtils;
025    import org.jetbrains.jet.utils.WrappedValues;
026    
027    import java.util.concurrent.ConcurrentHashMap;
028    import java.util.concurrent.ConcurrentMap;
029    import java.util.concurrent.locks.Lock;
030    import java.util.concurrent.locks.ReentrantLock;
031    
032    public class LockBasedStorageManager implements StorageManager {
033    
034        public static final StorageManager NO_LOCKS = new LockBasedStorageManager(NoLock.INSTANCE) {
035            @NotNull
036            @Override
037            protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
038                return RecursionDetectedResult.fallThrough();
039            }
040    
041            @Override
042            public String toString() {
043                return "NO_LOCKS";
044            }
045        };
046    
047        protected final Lock lock;
048    
049        public LockBasedStorageManager() {
050            this(new ReentrantLock());
051        }
052    
053        private LockBasedStorageManager(@NotNull Lock lock) {
054            this.lock = lock;
055        }
056    
057        @NotNull
058        @Override
059        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<K, V> compute) {
060            return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
061        }
062    
063        @NotNull
064        protected  <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
065                @NotNull Function1<K, V> compute,
066                @NotNull ConcurrentMap<K, Object> map
067        ) {
068            return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
069        }
070    
071        @NotNull
072        @Override
073        public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<K, V> compute) {
074            return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
075        }
076    
077        @NotNull
078        protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
079                @NotNull Function1<K, V> compute,
080                @NotNull ConcurrentMap<K, Object> map
081        ) {
082            return new MapBasedMemoizedFunction<K, V>(map, compute);
083        }
084    
085        @NotNull
086        @Override
087        public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<T> computable) {
088            return new LockBasedNotNullLazyValue<T>(computable);
089        }
090    
091        @NotNull
092        @Override
093        public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
094                @NotNull Function0<T> computable, @NotNull final T onRecursiveCall
095        ) {
096            return new LockBasedNotNullLazyValue<T>(computable) {
097                @NotNull
098                @Override
099                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
100                    return RecursionDetectedResult.value(onRecursiveCall);
101                }
102            };
103        }
104    
105        @NotNull
106        @Override
107        public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
108                @NotNull Function0<T> computable,
109                final Function1<Boolean, T> onRecursiveCall,
110                @NotNull final Function1<T, Unit> postCompute
111        ) {
112            return new LockBasedNotNullLazyValue<T>(computable) {
113                @NotNull
114                @Override
115                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
116                    if (onRecursiveCall == null) {
117                        return super.recursionDetected(firstTime);
118                    }
119                    return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
120                }
121    
122                @Override
123                protected void postCompute(@NotNull T value) {
124                    postCompute.invoke(value);
125                }
126            };
127        }
128    
129        @NotNull
130        @Override
131        public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<T> computable) {
132            return new LockBasedLazyValue<T>(computable);
133        }
134    
135        @NotNull
136        @Override
137        public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<T> computable, final T onRecursiveCall) {
138            return new LockBasedLazyValue<T>(computable) {
139                @NotNull
140                @Override
141                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
142                    return RecursionDetectedResult.value(onRecursiveCall);
143                }
144            };
145        }
146    
147        @NotNull
148        @Override
149        public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
150                @NotNull Function0<T> computable, @NotNull final Function1<T, Unit> postCompute
151        ) {
152            return new LockBasedLazyValue<T>(computable) {
153                @Override
154                protected void postCompute(@Nullable T value) {
155                    postCompute.invoke(value);
156                }
157            };
158        }
159    
160        @Override
161        public <T> T compute(@NotNull Function0<T> computable) {
162            lock.lock();
163            try {
164                return computable.invoke();
165            }
166            finally {
167                lock.unlock();
168            }
169        }
170    
171        @NotNull
172        protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
173            throw new IllegalStateException("Recursive call in a lazy value");
174        }
175    
176        private static class RecursionDetectedResult<T> {
177    
178            @NotNull
179            public static <T> RecursionDetectedResult<T> value(T value) {
180                return new RecursionDetectedResult<T>(value, false);
181            }
182    
183            @NotNull
184            public static <T> RecursionDetectedResult<T> fallThrough() {
185                return new RecursionDetectedResult<T>(null, true);
186            }
187    
188            private final T value;
189            private final boolean fallThrough;
190    
191            private RecursionDetectedResult(T value, boolean fallThrough) {
192                this.value = value;
193                this.fallThrough = fallThrough;
194            }
195    
196            public T getValue() {
197                assert !fallThrough : "A value requested from FALL_THROUGH ";
198                return value;
199            }
200    
201            public boolean isFallThrough() {
202                return fallThrough;
203            }
204    
205            @Override
206            public String toString() {
207                return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
208            }
209        }
210    
211        private enum NotValue {
212            NOT_COMPUTED,
213            COMPUTING,
214            RECURSION_WAS_DETECTED
215        }
216    
217        private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
218    
219            private final Function0<T> computable;
220    
221            @Nullable
222            private volatile Object value = NotValue.NOT_COMPUTED;
223    
224            public LockBasedLazyValue(@NotNull Function0<T> computable) {
225                this.computable = computable;
226            }
227    
228            @Override
229            public boolean isComputed() {
230                return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
231            }
232    
233            @Override
234            public T invoke() {
235                Object _value = value;
236                if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
237    
238                lock.lock();
239                try {
240                    _value = value;
241                    if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
242    
243                    if (_value == NotValue.COMPUTING) {
244                        value = NotValue.RECURSION_WAS_DETECTED;
245                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
246                        if (!result.isFallThrough()) {
247                            return result.getValue();
248                        }
249                    }
250    
251                    if (_value == NotValue.RECURSION_WAS_DETECTED) {
252                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
253                        if (!result.isFallThrough()) {
254                            return result.getValue();
255                        }
256                    }
257    
258                    value = NotValue.COMPUTING;
259                    try {
260                        T typedValue = computable.invoke();
261                        value = typedValue;
262                        postCompute(typedValue);
263                        return typedValue;
264                    }
265                    catch (Throwable throwable) {
266                        if (value == NotValue.COMPUTING) {
267                            // Store only if it's a genuine result, not something thrown through recursionDetected()
268                            value = WrappedValues.escapeThrowable(throwable);
269                        }
270                        throw ExceptionUtils.rethrow(throwable);
271                    }
272                }
273                finally {
274                    lock.unlock();
275                }
276            }
277    
278            /**
279             * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
280             * @return a value to be returned on a recursive call or subsequent calls
281             */
282            @NotNull
283            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
284                return recursionDetectedDefault();
285            }
286    
287            protected void postCompute(T value) {
288                // Doing something in post-compute helps prevent infinite recursion
289            }
290        }
291    
292        private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
293    
294            public LockBasedNotNullLazyValue(@NotNull Function0<T> computable) {
295                super(computable);
296            }
297    
298            @Override
299            @NotNull
300            public T invoke() {
301                T result = super.invoke();
302                assert result != null : "compute() returned null";
303                return result;
304            }
305        }
306    
307        private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
308            private final ConcurrentMap<K, Object> cache;
309            private final Function1<K, V> compute;
310    
311            public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<K, V> compute) {
312                this.cache = map;
313                this.compute = compute;
314            }
315    
316            @Override
317            @Nullable
318            public V invoke(K input) {
319                Object value = cache.get(input);
320                if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
321    
322                lock.lock();
323                try {
324                    value = cache.get(input);
325                    if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
326    
327                    AssertionError error = null;
328                    try {
329                        V typedValue = compute.invoke(input);
330                        Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
331    
332                        // This code effectively asserts that oldValue is null
333                        // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
334                        // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
335                        // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
336                        if (oldValue != null) {
337                            error = new AssertionError("Race condition or recursion detected. Old value is " + oldValue);
338                            throw error;
339                        }
340    
341                        return typedValue;
342                    }
343                    catch (Throwable throwable) {
344                        if (throwable == error) throw error;
345    
346                        Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
347                        assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue;
348    
349                        throw ExceptionUtils.rethrow(throwable);
350                    }
351                }
352                finally {
353                    lock.unlock();
354                }
355            }
356        }
357    
358        private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
359    
360            public MapBasedMemoizedFunctionToNotNull(
361                    @NotNull ConcurrentMap<K, Object> map,
362                    @NotNull Function1<K, V> compute
363            ) {
364                super(map, compute);
365            }
366    
367            @NotNull
368            @Override
369            public V invoke(K input) {
370                V result = super.invoke(input);
371                assert result != null : "compute() returned null";
372                return result;
373            }
374        }
375    }