001    /*
002     * Copyright 2010-2015 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.kotlin.js.inline;
018    
019    import com.google.dart.compiler.backend.js.ast.*;
020    import com.google.dart.compiler.backend.js.ast.metadata.MetadataPackage;
021    import com.intellij.psi.PsiElement;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.annotations.Nullable;
024    import org.jetbrains.kotlin.builtins.InlineStrategy;
025    import org.jetbrains.kotlin.diagnostics.DiagnosticSink;
026    import org.jetbrains.kotlin.js.inline.context.*;
027    import org.jetbrains.kotlin.js.resolve.diagnostics.ErrorsJs;
028    import org.jetbrains.kotlin.js.translate.context.TranslationContext;
029    
030    import java.util.*;
031    
032    import static org.jetbrains.kotlin.js.inline.FunctionInlineMutator.getInlineableCallReplacement;
033    import static org.jetbrains.kotlin.js.inline.clean.CleanPackage.removeUnusedFunctionDefinitions;
034    import static org.jetbrains.kotlin.js.inline.clean.CleanPackage.removeUnusedLocalFunctionDeclarations;
035    import static org.jetbrains.kotlin.js.inline.util.UtilPackage.*;
036    import static org.jetbrains.kotlin.js.translate.utils.JsAstUtils.flattenStatement;
037    
038    public class JsInliner extends JsVisitorWithContextImpl {
039    
040        private final IdentityHashMap<JsName, JsFunction> functions;
041        private final Stack<JsInliningContext> inliningContexts = new Stack<JsInliningContext>();
042        private final Set<JsFunction> processedFunctions = IdentitySet();
043        private final Set<JsFunction> inProcessFunctions = IdentitySet();
044        private final FunctionReader functionReader;
045        private final DiagnosticSink trace;
046    
047        // these are needed for error reporting, when inliner detects cycle
048        private final Stack<JsFunction> namedFunctionsStack = new Stack<JsFunction>();
049        private final LinkedList<JsCallInfo> inlineCallInfos = new LinkedList<JsCallInfo>();
050    
051        /**
052         * A statement can contain more, than one inlineable sub-expressions.
053         * When inline call is expanded, current statement is shifted forward,
054         * but still has same statement context with same index on stack.
055         *
056         * The shifting is intentional, because there could be function literals,
057         * that need to be inlined, after expansion.
058         *
059         * After shifting following inline expansion in the same statement could be
060         * incorrect, because wrong statement index is used.
061         *
062         * To prevent this, after every shift this flag is set to true,
063         * so that visitor wont go deeper until statement is visited.
064         *
065         * Example:
066         *  inline fun f(g: () -> Int): Int { val a = g(); return a }
067         *  inline fun Int.abs(): Int = if (this < 0) -this else this
068         *
069         *  val g = { 10 }
070         *  >> val h = f(g).abs()    // last statement context index
071         *
072         *  val g = { 10 }           // after inline
073         *  >> val f$result          // statement index was not changed
074         *  val a = g()
075         *  f$result = a
076         *  val h = f$result.abs()   // current expression still here; incorrect to inline abs(),
077         *                           //  because statement context on stack point to different statement
078         */
079        private boolean lastStatementWasShifted = false;
080    
081        public static JsProgram process(@NotNull TranslationContext context) {
082            JsProgram program = context.program();
083            IdentityHashMap<JsName, JsFunction> functions = collectNamedFunctions(program);
084            JsInliner inliner = new JsInliner(functions, new FunctionReader(context), context.bindingTrace());
085            inliner.accept(program);
086            removeUnusedFunctionDefinitions(program, functions);
087            return program;
088        }
089    
090        private JsInliner(
091                @NotNull IdentityHashMap<JsName, JsFunction> functions,
092                @NotNull FunctionReader functionReader,
093                @NotNull DiagnosticSink trace
094        ) {
095            this.functions = functions;
096            this.functionReader = functionReader;
097            this.trace = trace;
098        }
099    
100        @Override
101        public boolean visit(JsFunction function, JsContext context) {
102            inliningContexts.push(new JsInliningContext(function));
103            assert !inProcessFunctions.contains(function): "Inliner has revisited function";
104            inProcessFunctions.add(function);
105    
106            if (functions.containsValue(function)) {
107                namedFunctionsStack.push(function);
108            }
109    
110            return super.visit(function, context);
111        }
112    
113        @Override
114        public void endVisit(JsFunction function, JsContext context) {
115            super.endVisit(function, context);
116            refreshLabelNames(getInliningContext().newNamingContext(), function);
117    
118            removeUnusedLocalFunctionDeclarations(function);
119            processedFunctions.add(function);
120    
121            assert inProcessFunctions.contains(function);
122            inProcessFunctions.remove(function);
123    
124            inliningContexts.pop();
125    
126            if (!namedFunctionsStack.empty() && namedFunctionsStack.peek() == function) {
127                namedFunctionsStack.pop();
128            }
129        }
130    
131        @Override
132        public boolean visit(JsInvocation call, JsContext context) {
133            if (shouldInline(call) && canInline(call)) {
134                JsFunction containingFunction = getCurrentNamedFunction();
135                if (containingFunction != null) {
136                    inlineCallInfos.add(new JsCallInfo(call, containingFunction));
137                }
138    
139                JsFunction definition = getFunctionContext().getFunctionDefinition(call);
140    
141                if (inProcessFunctions.contains(definition))  {
142                    reportInlineCycle(call, definition);
143                    return false;
144                }
145    
146                if (!processedFunctions.contains(definition)) {
147                    accept(definition);
148                }
149    
150                inline(call, context);
151            }
152    
153            return !lastStatementWasShifted;
154        }
155    
156        @Override
157        public void endVisit(JsInvocation x, JsContext ctx) {
158            JsCallInfo lastCallInfo = null;
159    
160            if (!inlineCallInfos.isEmpty()) {
161                lastCallInfo = inlineCallInfos.getLast();
162            }
163    
164            if (lastCallInfo != null && lastCallInfo.call == x) {
165                inlineCallInfos.removeLast();
166            }
167        }
168    
169        private void inline(@NotNull JsInvocation call, @NotNull JsContext context) {
170            JsInliningContext inliningContext = getInliningContext();
171            FunctionContext functionContext = getFunctionContext();
172            functionContext.declareFunctionConstructorCalls(call.getArguments());
173            InlineableResult inlineableResult = getInlineableCallReplacement(call, inliningContext);
174    
175            JsStatement inlineableBody = inlineableResult.getInlineableBody();
176            JsExpression resultExpression = inlineableResult.getResultExpression();
177            StatementContext statementContext = inliningContext.getStatementContext();
178            accept(inlineableBody);
179    
180            /**
181             * Assumes, that resultExpression == null, when result is not needed.
182             * @see FunctionInlineMutator.isResultNeeded()
183             */
184            if (resultExpression == null) {
185                statementContext.removeCurrentStatement();
186            } else {
187                context.replaceMe(resultExpression);
188            }
189    
190            /** @see #lastStatementWasShifted */
191            statementContext.shiftCurrentStatementForward();
192            InsertionPoint<JsStatement> insertionPoint = statementContext.getInsertionPoint();
193            insertionPoint.insertAllAfter(flattenStatement(inlineableBody));
194        }
195    
196        /**
197         * Prevents JsInliner from traversing sub-expressions,
198         * when current statement was shifted forward.
199         */
200        @Override
201        protected <T extends JsNode> void doTraverse(T node, JsContext ctx) {
202            if (node instanceof JsStatement) {
203                /** @see #lastStatementWasShifted */
204                lastStatementWasShifted = false;
205            }
206    
207            if (!lastStatementWasShifted) {
208                super.doTraverse(node, ctx);
209            }
210        }
211    
212        @NotNull
213        private JsInliningContext getInliningContext() {
214            return inliningContexts.peek();
215        }
216    
217        @NotNull FunctionContext getFunctionContext() {
218            return getInliningContext().getFunctionContext();
219        }
220    
221        @Nullable
222        private JsFunction getCurrentNamedFunction() {
223            if (namedFunctionsStack.empty()) return null;
224            return namedFunctionsStack.peek();
225        }
226    
227        private void reportInlineCycle(@NotNull JsInvocation call, @NotNull JsFunction calledFunction) {
228            MetadataPackage.setInlineStrategy(call, InlineStrategy.NOT_INLINE);
229            Iterator<JsCallInfo> it = inlineCallInfos.descendingIterator();
230    
231            while (it.hasNext()) {
232                JsCallInfo callInfo = it.next();
233                PsiElement psiElement = MetadataPackage.getPsiElement(callInfo.call);
234    
235                if (psiElement != null) {
236                    trace.report(ErrorsJs.INLINE_CALL_CYCLE.on(psiElement));
237                }
238    
239                if (callInfo.containingFunction == calledFunction) {
240                    break;
241                }
242            }
243        }
244    
245        private boolean canInline(@NotNull JsInvocation call) {
246            FunctionContext functionContext = getFunctionContext();
247            return functionContext.hasFunctionDefinition(call);
248        }
249    
250        private static boolean shouldInline(@NotNull JsInvocation call) {
251            InlineStrategy strategy = MetadataPackage.getInlineStrategy(call);
252            return strategy != null && strategy.isInline();
253        }
254    
255        private class JsInliningContext implements InliningContext {
256            private final FunctionContext functionContext;
257    
258            JsInliningContext(JsFunction function) {
259                functionContext = new FunctionContext(function, this, functionReader) {
260                    @Nullable
261                    @Override
262                    protected JsFunction lookUpStaticFunction(@Nullable JsName functionName) {
263                        return functions.get(functionName);
264                    }
265                };
266            }
267    
268            @NotNull
269            @Override
270            public NamingContext newNamingContext() {
271                JsScope scope = getFunctionContext().getScope();
272                InsertionPoint<JsStatement> insertionPoint = getStatementContext().getInsertionPoint();
273                return new NamingContext(scope, insertionPoint);
274            }
275    
276            @NotNull
277            @Override
278            public StatementContext getStatementContext() {
279                return new StatementContext() {
280                    @NotNull
281                    @Override
282                    public JsContext getCurrentStatementContext() {
283                        return getLastStatementLevelContext();
284                    }
285    
286                    @NotNull
287                    @Override
288                    protected JsStatement getEmptyStatement() {
289                        return getFunctionContext().getEmpty();
290                    }
291    
292                    @Override
293                    public void shiftCurrentStatementForward() {
294                        super.shiftCurrentStatementForward();
295                        lastStatementWasShifted = true;
296                    }
297                };
298            }
299    
300            @NotNull
301            @Override
302            public FunctionContext getFunctionContext() {
303                return functionContext;
304            }
305        }
306    
307        private static class JsCallInfo {
308            @NotNull
309            public final JsInvocation call;
310    
311            @NotNull
312            public final JsFunction containingFunction;
313    
314            private JsCallInfo(@NotNull JsInvocation call, @NotNull JsFunction function) {
315                this.call = call;
316                containingFunction = function;
317            }
318        }
319    }