001package io.avaje.inject.test;
002
003import io.avaje.inject.BeanScope;
004import io.avaje.inject.BeanScopeBuilder;
005import org.junit.jupiter.api.extension.AfterEachCallback;
006import org.junit.jupiter.api.extension.BeforeAllCallback;
007import org.junit.jupiter.api.extension.BeforeEachCallback;
008import org.junit.jupiter.api.extension.ExtensionContext;
009import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
010import org.slf4j.Logger;
011import org.slf4j.LoggerFactory;
012
013import java.util.ArrayList;
014import java.util.Iterator;
015import java.util.List;
016import java.util.ServiceLoader;
017import java.util.concurrent.locks.ReentrantLock;
018
019/**
020 * Junit 5 extension for avaje inject.
021 * <p>
022 * Supports injection for fields annotated with <code>@Mock, @Spy, @Captor, @Inject</code>.
023 */
024public class InjectExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback, ExtensionContext.Store.CloseableResource {
025
026  private static final Logger log = LoggerFactory.getLogger(InjectExtension.class);
027  private static final Namespace INJECT_NS = Namespace.create("io.avaje.inject.InjectTest");
028  private static final String BEAN_SCOPE = "BEAN_SCOPE";
029  private static final ReentrantLock lock = new ReentrantLock();
030  private static boolean started;
031  private static BeanScope globalTestScope;
032
033  @Override
034  public void beforeAll(ExtensionContext context) {
035    lock.lock();
036    try {
037      if (!started) {
038        initialiseGlobalTestScope(context);
039        started = true;
040      }
041    } finally {
042      lock.unlock();
043    }
044  }
045
046  @Override
047  public void close() throws Throwable {
048    lock.lock();
049    try {
050      if (globalTestScope != null) {
051        log.debug("Closing global test BeanScope");
052        globalTestScope.close();
053      }
054    } finally {
055      lock.unlock();
056    }
057  }
058
059  private void initialiseGlobalTestScope(ExtensionContext context) {
060    Iterator<TestModule> iterator = ServiceLoader.load(TestModule.class).iterator();
061    if (iterator.hasNext()) {
062      log.debug("Building global test BeanScope (as parent scope for tests)");
063      globalTestScope = BeanScope.newBuilder()
064        .withModules(iterator.next())
065        .build();
066
067      log.trace("register global test BeanScope with beans {}", globalTestScope);
068      context.getRoot().getStore(Namespace.GLOBAL).put(InjectExtension.class.getCanonicalName(), this);
069    }
070  }
071
072  /**
073   * Callback that is invoked <em>before</em> each test is invoked.
074   */
075  @Override
076  public void beforeEach(final ExtensionContext context) {
077    final List<MetaReader> readers = createMetaReaders(context);
078
079    final BeanScopeBuilder builder = BeanScope.newBuilder();
080    if (globalTestScope != null) {
081      builder.withParent(globalTestScope, false);
082    }
083    // register mocks and spies local to this test
084    for (MetaReader reader : readers) {
085      reader.build(builder);
086    }
087
088    // wire with local mocks, spies, and globalTestScope
089    final BeanScope beanScope = builder.build();
090    for (MetaReader reader : readers) {
091      reader.setFromScope(beanScope);
092    }
093    log.trace("test setup with {}", readers);
094    context.getStore(INJECT_NS).put(BEAN_SCOPE, beanScope);
095  }
096
097  /**
098   * Return the list of MetaReaders - 1 per test instance.
099   */
100  private List<MetaReader> createMetaReaders(ExtensionContext context) {
101    final List<Object> testInstances = context.getRequiredTestInstances().getAllInstances();
102    final List<MetaReader> readers = new ArrayList<>(testInstances.size());
103    for (Object testInstance : testInstances) {
104      readers.add(new MetaReader(testInstance));
105    }
106    return readers;
107  }
108
109  /**
110   * Callback that is invoked <em>after</em> each test has been invoked.
111   */
112  @Override
113  public void afterEach(ExtensionContext context) {
114    final BeanScope beanScope = (BeanScope) context.getStore(INJECT_NS).remove(BEAN_SCOPE);
115    if (beanScope != null) {
116      beanScope.close();
117    }
118  }
119
120}