001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.storage;
018    
019    import jet.Function0;
020    import jet.Function1;
021    import jet.Unit;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.annotations.Nullable;
024    import org.jetbrains.jet.utils.ExceptionUtils;
025    import org.jetbrains.jet.utils.WrappedValues;
026    
027    import java.util.concurrent.ConcurrentHashMap;
028    import java.util.concurrent.ConcurrentMap;
029    import java.util.concurrent.locks.Lock;
030    import java.util.concurrent.locks.ReentrantLock;
031    
032    public class LockBasedStorageManager implements StorageManager {
033    
034        public static final StorageManager NO_LOCKS = new LockBasedStorageManager(NoLock.INSTANCE) {
035            @Override
036            public String toString() {
037                return "NO_LOCKS";
038            }
039        };
040    
041        protected final Lock lock;
042    
043        public LockBasedStorageManager() {
044            this(new ReentrantLock());
045        }
046    
047        private LockBasedStorageManager(@NotNull Lock lock) {
048            this.lock = lock;
049        }
050    
051        @NotNull
052        @Override
053        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<K, V> compute) {
054            return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
055        }
056    
057        @NotNull
058        protected  <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
059                @NotNull Function1<K, V> compute,
060                @NotNull ConcurrentMap<K, Object> map
061        ) {
062            return new MapBasedMemoizedFunctionToNotNull<K, V>(lock, map, compute);
063        }
064    
065        @NotNull
066        @Override
067        public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<K, V> compute) {
068            return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
069        }
070    
071        @NotNull
072        protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
073                @NotNull Function1<K, V> compute,
074                @NotNull ConcurrentMap<K, Object> map
075        ) {
076            return new MapBasedMemoizedFunction<K, V>(lock, map, compute);
077        }
078    
079        @NotNull
080        @Override
081        public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<T> computable) {
082            return new LockBasedNotNullLazyValue<T>(lock, computable);
083        }
084    
085        @NotNull
086        @Override
087        public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
088                @NotNull Function0<T> computable, @NotNull final T onRecursiveCall
089        ) {
090            return new LockBasedNotNullLazyValue<T>(lock, computable) {
091                @Override
092                protected T recursionDetected(boolean firstTime) {
093                    return onRecursiveCall;
094                }
095            };
096        }
097    
098        @NotNull
099        @Override
100        public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
101                @NotNull Function0<T> computable,
102                final Function1<Boolean, T> onRecursiveCall,
103                @NotNull final Function1<T, Unit> postCompute
104        ) {
105            return new LockBasedNotNullLazyValue<T>(lock, computable) {
106                @Nullable
107                @Override
108                protected T recursionDetected(boolean firstTime) {
109                    if (onRecursiveCall == null) {
110                        return super.recursionDetected(firstTime);
111                    }
112                    return onRecursiveCall.invoke(firstTime);
113                }
114    
115                @Override
116                protected void postCompute(@NotNull T value) {
117                    postCompute.invoke(value);
118                }
119            };
120        }
121    
122        @NotNull
123        @Override
124        public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<T> computable) {
125            return new LockBasedLazyValue<T>(lock, computable);
126        }
127    
128        @NotNull
129        @Override
130        public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<T> computable, final T onRecursiveCall) {
131            return new LockBasedLazyValue<T>(lock, computable) {
132                @Override
133                protected T recursionDetected(boolean firstTime) {
134                    return onRecursiveCall;
135                }
136            };
137        }
138    
139        @NotNull
140        @Override
141        public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
142                @NotNull Function0<T> computable, @NotNull final Function1<T, Unit> postCompute
143        ) {
144            return new LockBasedLazyValue<T>(lock, computable) {
145                @Override
146                protected void postCompute(@Nullable T value) {
147                    postCompute.invoke(value);
148                }
149            };
150        }
151    
152        @Override
153        public <T> T compute(@NotNull Function0<T> computable) {
154            lock.lock();
155            try {
156                return computable.invoke();
157            }
158            finally {
159                lock.unlock();
160            }
161        }
162    
163        private static class LockBasedLazyValue<T> implements NullableLazyValue<T> {
164    
165            private enum NotValue {
166                NOT_COMPUTED,
167                COMPUTING,
168                RECURSION_WAS_DETECTED
169            }
170    
171            private final Lock lock;
172            private final Function0<T> computable;
173    
174            @Nullable
175            private volatile Object value = NotValue.NOT_COMPUTED;
176    
177            public LockBasedLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) {
178                this.lock = lock;
179                this.computable = computable;
180            }
181    
182            @Override
183            public boolean isComputed() {
184                return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
185            }
186    
187            @Override
188            public T invoke() {
189                Object _value = value;
190                if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
191    
192                lock.lock();
193                try {
194                    _value = value;
195                    if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
196    
197                    if (_value == NotValue.COMPUTING) {
198                        value = NotValue.RECURSION_WAS_DETECTED;
199                        return recursionDetected(/*firstTime = */ true);
200                    }
201    
202                    if (_value == NotValue.RECURSION_WAS_DETECTED) {
203                        return recursionDetected(/*firstTime = */ false);
204                    }
205    
206                    value = NotValue.COMPUTING;
207                    try {
208                        T typedValue = computable.invoke();
209                        value = typedValue;
210                        postCompute(typedValue);
211                        return typedValue;
212                    }
213                    catch (Throwable throwable) {
214                        if (value == NotValue.COMPUTING) {
215                            // Store only if it's a genuine result, not something thrown through recursionDetected()
216                            value = WrappedValues.escapeThrowable(throwable);
217                        }
218                        throw ExceptionUtils.rethrow(throwable);
219                    }
220                }
221                finally {
222                    lock.unlock();
223                }
224            }
225    
226            /**
227             * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
228             * @return a value to be returned on a recursive call or subsequent calls
229             */
230            @Nullable
231            protected T recursionDetected(boolean firstTime) {
232                throw new IllegalStateException("Recursive call in a lazy value");
233            }
234    
235            protected void postCompute(T value) {
236                // Doing something in post-compute helps prevent infinite recursion
237            }
238        }
239    
240        private static class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
241    
242            public LockBasedNotNullLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) {
243                super(lock, computable);
244            }
245    
246            @Override
247            @NotNull
248            public T invoke() {
249                T result = super.invoke();
250                assert result != null : "compute() returned null";
251                return result;
252            }
253        }
254    
255        private static class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
256            private final Lock lock;
257            private final ConcurrentMap<K, Object> cache;
258            private final Function1<K, V> compute;
259    
260            public MapBasedMemoizedFunction(@NotNull Lock lock, @NotNull ConcurrentMap<K, Object> map, @NotNull Function1<K, V> compute) {
261                this.lock = lock;
262                this.cache = map;
263                this.compute = compute;
264            }
265    
266            @Override
267            @Nullable
268            public V invoke(@NotNull K input) {
269                Object value = cache.get(input);
270                if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
271    
272                lock.lock();
273                try {
274                    value = cache.get(input);
275                    if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
276    
277                    try {
278                        V typedValue = compute.invoke(input);
279                        Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
280                        assert oldValue == null : "Race condition detected";
281    
282                        return typedValue;
283                    }
284                    catch (Throwable throwable) {
285                        Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
286                        assert oldValue == null : "Race condition detected";
287    
288                        throw ExceptionUtils.rethrow(throwable);
289                    }
290                }
291                finally {
292                    lock.unlock();
293                }
294            }
295        }
296    
297        private static class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
298    
299            public MapBasedMemoizedFunctionToNotNull(
300                    @NotNull Lock lock,
301                    @NotNull ConcurrentMap<K, Object> map,
302                    @NotNull Function1<K, V> compute
303            ) {
304                super(lock, map, compute);
305            }
306    
307            @NotNull
308            @Override
309            public V invoke(@NotNull K input) {
310                V result = super.invoke(input);
311                assert result != null : "compute() returned null";
312                return result;
313            }
314        }
315    }