001 /* 002 * Copyright 2010-2013 JetBrains s.r.o. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017 package org.jetbrains.jet.storage; 018 019 import jet.Function0; 020 import jet.Function1; 021 import jet.Unit; 022 import org.jetbrains.annotations.NotNull; 023 import org.jetbrains.annotations.Nullable; 024 import org.jetbrains.jet.utils.ExceptionUtils; 025 import org.jetbrains.jet.utils.WrappedValues; 026 027 import java.util.concurrent.ConcurrentHashMap; 028 import java.util.concurrent.ConcurrentMap; 029 import java.util.concurrent.locks.Lock; 030 import java.util.concurrent.locks.ReentrantLock; 031 032 public class LockBasedStorageManager implements StorageManager { 033 034 public static final StorageManager NO_LOCKS = new LockBasedStorageManager(NoLock.INSTANCE) { 035 @NotNull 036 @Override 037 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 038 return RecursionDetectedResult.fallThrough(); 039 } 040 041 @Override 042 public String toString() { 043 return "NO_LOCKS"; 044 } 045 }; 046 047 protected final Lock lock; 048 049 public LockBasedStorageManager() { 050 this(new ReentrantLock()); 051 } 052 053 private LockBasedStorageManager(@NotNull Lock lock) { 054 this.lock = lock; 055 } 056 057 @NotNull 058 @Override 059 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<K, V> compute) { 060 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>()); 061 } 062 063 @NotNull 064 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction( 065 @NotNull Function1<K, V> compute, 066 @NotNull ConcurrentMap<K, Object> map 067 ) { 068 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute); 069 } 070 071 @NotNull 072 @Override 073 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<K, V> compute) { 074 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>()); 075 } 076 077 @NotNull 078 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues( 079 @NotNull Function1<K, V> compute, 080 @NotNull ConcurrentMap<K, Object> map 081 ) { 082 return new MapBasedMemoizedFunction<K, V>(map, compute); 083 } 084 085 @NotNull 086 @Override 087 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<T> computable) { 088 return new LockBasedNotNullLazyValue<T>(computable); 089 } 090 091 @NotNull 092 @Override 093 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue( 094 @NotNull Function0<T> computable, @NotNull final T onRecursiveCall 095 ) { 096 return new LockBasedNotNullLazyValue<T>(computable) { 097 @NotNull 098 @Override 099 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 100 return RecursionDetectedResult.value(onRecursiveCall); 101 } 102 }; 103 } 104 105 @NotNull 106 @Override 107 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute( 108 @NotNull Function0<T> computable, 109 final Function1<Boolean, T> onRecursiveCall, 110 @NotNull final Function1<T, Unit> postCompute 111 ) { 112 return new LockBasedNotNullLazyValue<T>(computable) { 113 @NotNull 114 @Override 115 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 116 if (onRecursiveCall == null) { 117 return super.recursionDetected(firstTime); 118 } 119 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime)); 120 } 121 122 @Override 123 protected void postCompute(@NotNull T value) { 124 postCompute.invoke(value); 125 } 126 }; 127 } 128 129 @NotNull 130 @Override 131 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<T> computable) { 132 return new LockBasedLazyValue<T>(computable); 133 } 134 135 @NotNull 136 @Override 137 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<T> computable, final T onRecursiveCall) { 138 return new LockBasedLazyValue<T>(computable) { 139 @NotNull 140 @Override 141 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 142 return RecursionDetectedResult.value(onRecursiveCall); 143 } 144 }; 145 } 146 147 @NotNull 148 @Override 149 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute( 150 @NotNull Function0<T> computable, @NotNull final Function1<T, Unit> postCompute 151 ) { 152 return new LockBasedLazyValue<T>(computable) { 153 @Override 154 protected void postCompute(@Nullable T value) { 155 postCompute.invoke(value); 156 } 157 }; 158 } 159 160 @Override 161 public <T> T compute(@NotNull Function0<T> computable) { 162 lock.lock(); 163 try { 164 return computable.invoke(); 165 } 166 finally { 167 lock.unlock(); 168 } 169 } 170 171 @NotNull 172 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() { 173 throw new IllegalStateException("Recursive call in a lazy value"); 174 } 175 176 private static class RecursionDetectedResult<T> { 177 178 @NotNull 179 public static <T> RecursionDetectedResult<T> value(T value) { 180 return new RecursionDetectedResult<T>(value, false); 181 } 182 183 @NotNull 184 public static <T> RecursionDetectedResult<T> fallThrough() { 185 return new RecursionDetectedResult<T>(null, true); 186 } 187 188 private final T value; 189 private final boolean fallThrough; 190 191 private RecursionDetectedResult(T value, boolean fallThrough) { 192 this.value = value; 193 this.fallThrough = fallThrough; 194 } 195 196 public T getValue() { 197 assert !fallThrough : "A value requested from FALL_THROUGH "; 198 return value; 199 } 200 201 public boolean isFallThrough() { 202 return fallThrough; 203 } 204 205 @Override 206 public String toString() { 207 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value); 208 } 209 } 210 211 private enum NotValue { 212 NOT_COMPUTED, 213 COMPUTING, 214 RECURSION_WAS_DETECTED 215 } 216 217 private class LockBasedLazyValue<T> implements NullableLazyValue<T> { 218 219 private final Function0<T> computable; 220 221 @Nullable 222 private volatile Object value = NotValue.NOT_COMPUTED; 223 224 public LockBasedLazyValue(@NotNull Function0<T> computable) { 225 this.computable = computable; 226 } 227 228 @Override 229 public boolean isComputed() { 230 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING; 231 } 232 233 @Override 234 public T invoke() { 235 Object _value = value; 236 if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 237 238 lock.lock(); 239 try { 240 _value = value; 241 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value); 242 243 if (_value == NotValue.COMPUTING) { 244 value = NotValue.RECURSION_WAS_DETECTED; 245 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true); 246 if (!result.isFallThrough()) { 247 return result.getValue(); 248 } 249 } 250 251 if (_value == NotValue.RECURSION_WAS_DETECTED) { 252 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false); 253 if (!result.isFallThrough()) { 254 return result.getValue(); 255 } 256 } 257 258 value = NotValue.COMPUTING; 259 try { 260 T typedValue = computable.invoke(); 261 value = typedValue; 262 postCompute(typedValue); 263 return typedValue; 264 } 265 catch (Throwable throwable) { 266 if (value == NotValue.COMPUTING) { 267 // Store only if it's a genuine result, not something thrown through recursionDetected() 268 value = WrappedValues.escapeThrowable(throwable); 269 } 270 throw ExceptionUtils.rethrow(throwable); 271 } 272 } 273 finally { 274 lock.unlock(); 275 } 276 } 277 278 /** 279 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise 280 * @return a value to be returned on a recursive call or subsequent calls 281 */ 282 @NotNull 283 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) { 284 return recursionDetectedDefault(); 285 } 286 287 protected void postCompute(T value) { 288 // Doing something in post-compute helps prevent infinite recursion 289 } 290 } 291 292 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> { 293 294 public LockBasedNotNullLazyValue(@NotNull Function0<T> computable) { 295 super(computable); 296 } 297 298 @Override 299 @NotNull 300 public T invoke() { 301 T result = super.invoke(); 302 assert result != null : "compute() returned null"; 303 return result; 304 } 305 } 306 307 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> { 308 private final ConcurrentMap<K, Object> cache; 309 private final Function1<K, V> compute; 310 311 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<K, V> compute) { 312 this.cache = map; 313 this.compute = compute; 314 } 315 316 @Override 317 @Nullable 318 public V invoke(K input) { 319 Object value = cache.get(input); 320 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 321 322 lock.lock(); 323 try { 324 value = cache.get(input); 325 if (value != null) return WrappedValues.unescapeExceptionOrNull(value); 326 327 AssertionError error = null; 328 try { 329 V typedValue = compute.invoke(input); 330 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue)); 331 332 // This code effectively asserts that oldValue is null 333 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored 334 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that 335 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed 336 if (oldValue != null) { 337 error = new AssertionError("Race condition or recursion detected. Old value is " + oldValue); 338 throw error; 339 } 340 341 return typedValue; 342 } 343 catch (Throwable throwable) { 344 if (throwable == error) throw error; 345 346 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable)); 347 assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue; 348 349 throw ExceptionUtils.rethrow(throwable); 350 } 351 } 352 finally { 353 lock.unlock(); 354 } 355 } 356 } 357 358 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> { 359 360 public MapBasedMemoizedFunctionToNotNull( 361 @NotNull ConcurrentMap<K, Object> map, 362 @NotNull Function1<K, V> compute 363 ) { 364 super(map, compute); 365 } 366 367 @NotNull 368 @Override 369 public V invoke(K input) { 370 V result = super.invoke(input); 371 assert result != null : "compute() returned null"; 372 return result; 373 } 374 } 375 }