001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.utils;
018    
019    import com.intellij.util.Function;
020    import com.intellij.util.NullableFunction;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.annotations.Nullable;
023    
024    import java.util.HashMap;
025    import java.util.Map;
026    
027    public abstract class NullableMemoizedFunction<K, V> implements NullableFunction<K, V> {
028    
029        public static <K, V> NullableFunction<K, V> create(@NotNull final Function<K, V> compute) {
030            return new NullableMemoizedFunction<K, V>() {
031                @Nullable
032                @Override
033                protected V compute(@NotNull K input) {
034                    return compute.fun(input);
035                }
036            };
037        }
038    
039        private final Map<K, Object> cache;
040    
041        public NullableMemoizedFunction(@NotNull Map<K, Object> map) {
042            this.cache = map;
043        }
044    
045        public NullableMemoizedFunction() {
046            this(new HashMap<K, Object>());
047        }
048    
049        @Override
050        @Nullable
051        public V fun(@NotNull K input) {
052            Object value = cache.get(input);
053            if (value != null) return WrappedValues.unescapeNull(value);
054    
055            V typedValue = compute(input);
056    
057            Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
058            assert oldValue == null : "Race condition detected";
059    
060            return typedValue;
061        }
062    
063        @Nullable
064        protected abstract V compute(@NotNull K input);
065    }