001    /*
002     * Copyright 2010-2014 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.util.ArrayUtil;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.org.objectweb.asm.Label;
023    import org.jetbrains.org.objectweb.asm.MethodVisitor;
024    import org.jetbrains.org.objectweb.asm.Opcodes;
025    import org.jetbrains.org.objectweb.asm.Type;
026    import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
027    import org.jetbrains.org.objectweb.asm.commons.Method;
028    import org.jetbrains.org.objectweb.asm.commons.RemappingMethodAdapter;
029    import org.jetbrains.org.objectweb.asm.tree.*;
030    import org.jetbrains.org.objectweb.asm.tree.analysis.*;
031    import org.jetbrains.jet.codegen.ClosureCodegen;
032    import org.jetbrains.jet.codegen.StackValue;
033    import org.jetbrains.jet.codegen.state.JetTypeMapper;
034    
035    import java.util.*;
036    
037    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;
038    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isLambdaConstructorCall;
039    
040    public class MethodInliner {
041    
042        private final MethodNode node;
043    
044        private final Parameters parameters;
045    
046        private final InliningContext inliningContext;
047    
048        private final FieldRemapper nodeRemapper;
049    
050        private final boolean isSameModule;
051    
052        private final String errorPrefix;
053    
054        private final JetTypeMapper typeMapper;
055    
056        private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
057    
058        //keeps order
059        private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
060        //current state
061        private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
062    
063        private final InlineResult result;
064    
065        /*
066         *
067         * @param node
068         * @param parameters
069         * @param inliningContext
070         * @param lambdaType - in case on lambda 'invoke' inlining
071         */
072        public MethodInliner(
073                @NotNull MethodNode node,
074                @NotNull Parameters parameters,
075                @NotNull InliningContext parent,
076                @NotNull FieldRemapper nodeRemapper,
077                boolean isSameModule,
078                @NotNull String errorPrefix
079        ) {
080            this.node = node;
081            this.parameters = parameters;
082            this.inliningContext = parent;
083            this.nodeRemapper = nodeRemapper;
084            this.isSameModule = isSameModule;
085            this.errorPrefix = errorPrefix;
086            this.typeMapper = parent.state.getTypeMapper();
087            this.result = InlineResult.create();
088        }
089    
090    
091        public InlineResult doInline(MethodVisitor adapter, LocalVarRemapper remapper) {
092            return doInline(adapter, remapper, true);
093        }
094    
095        public InlineResult doInline(
096                MethodVisitor adapter,
097                LocalVarRemapper remapper,
098                boolean remapReturn
099        ) {
100            //analyze body
101            MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);
102    
103            transformedNode = doInline(transformedNode);
104            removeClosureAssertions(transformedNode);
105            transformedNode.instructions.resetLabels();
106    
107            Label end = new Label();
108            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn, nodeRemapper);
109            try {
110                transformedNode.accept(visitor);
111            }
112            catch (Exception e) {
113                throw wrapException(e, transformedNode, "couldn't inline method call");
114            }
115    
116            visitor.visitLabel(end);
117    
118            return result;
119        }
120    
121        private MethodNode doInline(MethodNode node) {
122    
123            final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
124    
125            MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
126    
127            final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
128    
129            RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
130                                                                                       new TypeRemapper(currentTypeMapping));
131    
132            InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
133    
134                private ConstructorInvocation invocation;
135                @Override
136                public void anew(Type type) {
137                    if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
138                        invocation = iterator.next();
139    
140                        if (invocation.shouldRegenerate()) {
141                            //TODO: need poping of type but what to do with local funs???
142                            Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
143                            currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
144                            AnonymousObjectTransformer transformer =
145                                    new AnonymousObjectTransformer(invocation.getOwnerInternalName(),
146                                                                   inliningContext
147                                                                           .subInlineWithClassRegeneration(
148                                                                                   inliningContext.nameGenerator,
149                                                                                   currentTypeMapping,
150                                                                                   invocation),
151                                                                   isSameModule, newLambdaType
152                                    );
153    
154                            InlineResult transformResult = transformer.doTransform(invocation, nodeRemapper);
155                            result.addAllClassesToRemove(transformResult);
156    
157                            if (inliningContext.isInliningLambda) {
158                                //this class is transformed and original not used so we should remove original one after inlining
159                                result.addClassToRemove(invocation.getOwnerInternalName());
160                            }
161                        }
162                    }
163    
164                    //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
165                    super.anew(type);
166                }
167    
168                @Override
169                public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
170                    if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
171                        assert !currentInvokes.isEmpty();
172                        InvokeCall invokeCall = currentInvokes.remove();
173                        LambdaInfo info = invokeCall.lambdaInfo;
174    
175                        if (info == null) {
176                            //noninlinable lambda
177                            super.visitMethodInsn(opcode, owner, name, desc, itf);
178                            return;
179                        }
180    
181                        int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
182                        putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
183    
184                        Parameters lambdaParameters = info.addAllParameters();
185    
186                        InlinedLambdaRemapper newCapturedRemapper =
187                                new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), nodeRemapper, lambdaParameters);
188    
189                        setInlining(true);
190                        MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
191                                                                  inliningContext.subInlineLambda(info),
192                                                                  newCapturedRemapper, true /*cause all calls in same module as lambda*/,
193                                                                  "Lambda inlining " + info.getLambdaClassType().getInternalName());
194    
195                        LocalVarRemapper remapper = new LocalVarRemapper(lambdaParameters, valueParamShift);
196                        InlineResult lambdaResult = inliner.doInline(this.mv, remapper);//TODO add skipped this and receiver
197                        result.addAllClassesToRemove(lambdaResult);
198    
199                        //return value boxing/unboxing
200                        Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
201                        Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
202                        StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
203                        setInlining(false);
204                    }
205                    else if (isLambdaConstructorCall(owner, name)) { //TODO add method
206                        assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
207                        if (invocation.shouldRegenerate()) {
208                            //put additional captured parameters on stack
209                            for (CapturedParamInfo capturedParamInfo : invocation.getAllRecapturedParameters()) {
210                                visitFieldInsn(Opcodes.GETSTATIC, capturedParamInfo.getContainingLambdaName(), "$$$" + capturedParamInfo.getOriginalFieldName(), capturedParamInfo.getType().getDescriptor());
211                            }
212                            super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor(), itf);
213                            invocation = null;
214                        } else {
215                            super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
216                        }
217                    }
218                    else {
219                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
220                    }
221                }
222    
223            };
224    
225            node.accept(inliner);
226    
227            return resultNode;
228        }
229    
230        @NotNull
231        public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
232            assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
233            FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
234            CapturedParamInfo field = fieldRemapper.findField(fin);
235            if (field == null) {
236                throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
237            }
238            return field;
239        }
240    
241        @NotNull
242        public MethodNode prepareNode(@NotNull MethodNode node) {
243            final int capturedParamsSize = parameters.getCaptured().size();
244            final int realParametersSize = parameters.getReal().size();
245            Type[] types = Type.getArgumentTypes(node.desc);
246            Type returnType = Type.getReturnType(node.desc);
247    
248            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
249            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
250    
251            node.instructions.resetLabels();
252            MethodNode transformedNode = new MethodNode(InlineCodegenUtil.API, node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
253    
254                private final boolean isInliningLambda = nodeRemapper.isInsideInliningLambda();
255    
256                private int getNewIndex(int var) {
257                    return var + (var < realParametersSize ? 0 : capturedParamsSize);
258                }
259    
260                @Override
261                public void visitVarInsn(int opcode, int var) {
262                    super.visitVarInsn(opcode, getNewIndex(var));
263                }
264    
265                @Override
266                public void visitIincInsn(int var, int increment) {
267                    super.visitIincInsn(getNewIndex(var), increment);
268                }
269    
270                @Override
271                public void visitMaxs(int maxStack, int maxLocals) {
272                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
273                }
274    
275                @Override
276                public void visitLineNumber(int line, Label start) {
277                    if(isInliningLambda) {
278                        super.visitLineNumber(line, start);
279                    }
280                }
281    
282                @Override
283                public void visitLocalVariable(
284                        String name, String desc, String signature, Label start, Label end, int index
285                ) {
286                    if (isInliningLambda) {
287                        super.visitLocalVariable(name, desc, signature, start, end, getNewIndex(index));
288                    }
289                }
290            };
291    
292            node.accept(transformedNode);
293    
294            transformCaptured(transformedNode);
295    
296            return transformedNode;
297        }
298    
299        @NotNull
300        protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
301            node = prepareNode(node);
302    
303            Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
304            Frame<SourceValue>[] sources;
305            try {
306                sources = analyzer.analyze("fake", node);
307            }
308            catch (AnalyzerException e) {
309                throw wrapException(e, node, "couldn't inline method call");
310            }
311    
312            AbstractInsnNode cur = node.instructions.getFirst();
313            int index = 0;
314            Set<LabelNode> deadLabels = new HashSet<LabelNode>();
315    
316            while (cur != null) {
317                Frame<SourceValue> frame = sources[index];
318    
319                if (frame != null) {
320                    if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
321                        MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
322                        String owner = methodInsnNode.owner;
323                        String desc = methodInsnNode.desc;
324                        String name = methodInsnNode.name;
325                        //TODO check closure
326                        int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
327                        if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
328                            SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
329    
330                            LambdaInfo lambdaInfo = null;
331                            int varIndex = -1;
332    
333                            if (sourceValue.insns.size() == 1) {
334                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
335    
336                                lambdaInfo = getLambdaIfExists(insnNode);
337                                if (lambdaInfo != null) {
338                                    //remove inlinable access
339                                    node.instructions.remove(insnNode);
340                                }
341                            }
342    
343                            invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
344                        }
345                        else if (isLambdaConstructorCall(owner, name)) {
346                            Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
347                            int paramStart = frame.getStackSize() - paramLength;
348    
349                            for (int i = 0; i < paramLength; i++) {
350                                SourceValue sourceValue = frame.getStack(paramStart + i);
351                                if (sourceValue.insns.size() == 1) {
352                                    AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
353                                    LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
354                                    if (lambdaInfo != null) {
355                                        lambdaMapping.put(i, lambdaInfo);
356                                        node.instructions.remove(insnNode);
357                                    }
358                                }
359                            }
360    
361                            constructorInvocations.add(new ConstructorInvocation(owner, desc, lambdaMapping, isSameModule, inliningContext.classRegeneration));
362                        }
363                    }
364                }
365    
366                AbstractInsnNode prevNode = cur;
367                cur = cur.getNext();
368                index++;
369    
370                //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
371                if (frame == null) {
372                    //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
373                    if (prevNode.getType() == AbstractInsnNode.LABEL) {
374                        deadLabels.add((LabelNode) prevNode);
375                    } else {
376                        node.instructions.remove(prevNode);
377                    }
378                }
379            }
380    
381            //clean dead try/catch blocks
382            List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
383            for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
384                TryCatchBlockNode block = iterator.next();
385                if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
386                    iterator.remove();
387                }
388            }
389    
390            return node;
391        }
392    
393        public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
394            if (insnNode.getOpcode() == Opcodes.ALOAD) {
395                int varIndex = ((VarInsnNode) insnNode).var;
396                if (varIndex < parameters.totalSize()) {
397                    return parameters.get(varIndex).getLambda();
398                }
399            }
400            else if (insnNode instanceof FieldInsnNode) {
401                FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
402                if (fieldInsnNode.name.startsWith("$$$")) {
403                    return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
404                }
405            }
406    
407            return null;
408        }
409    
410        private static void removeClosureAssertions(MethodNode node) {
411            AbstractInsnNode cur = node.instructions.getFirst();
412            while (cur != null && cur.getNext() != null) {
413                AbstractInsnNode next = cur.getNext();
414                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
415                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
416                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
417                        AbstractInsnNode prev = cur.getPrevious();
418    
419                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
420                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
421    
422                        node.instructions.remove(prev);
423                        node.instructions.remove(cur);
424                        cur = next.getNext();
425                        node.instructions.remove(next);
426                        next = cur;
427                    }
428                }
429                cur = next;
430            }
431        }
432    
433        private void transformCaptured(@NotNull MethodNode node) {
434            if (nodeRemapper.isRoot()) {
435                return;
436            }
437    
438            //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
439            //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
440            AbstractInsnNode cur = node.instructions.getFirst();
441            while (cur != null) {
442                if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
443                    if (((VarInsnNode) cur).var == 0) {
444                        List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
445                        AbstractInsnNode insnNode = nodeRemapper.foldFieldAccessChainIfNeeded(accessChain, node);
446                        if (insnNode != null) {
447                            cur = insnNode;
448                        }
449                    }
450                }
451                cur = cur.getNext();
452            }
453        }
454    
455        @NotNull
456        public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
457            List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
458            fieldAccessChain.add(aload0);
459            AbstractInsnNode next = aload0.getNext();
460            while (next != null && next instanceof FieldInsnNode || next instanceof LabelNode) {
461                if (next instanceof LabelNode) {
462                    next = next.getNext();
463                    continue; //it will be delete on transformation
464                }
465                fieldAccessChain.add(next);
466                if ("this$0".equals(((FieldInsnNode) next).name)) {
467                    next = next.getNext();
468                }
469                else {
470                    break;
471                }
472            }
473    
474            return fieldAccessChain;
475        }
476    
477        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
478            Type[] actualParams = Type.getArgumentTypes(descriptor);
479            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
480    
481            int size = 0;
482            for (Type next : directOrder) {
483                size += next.getSize();
484            }
485    
486            shift += size;
487            int index = directOrder.size();
488    
489            for (Type next : Lists.reverse(directOrder)) {
490                shift -= next.getSize();
491                Type typeOnStack = actualParams[--index];
492                if (!typeOnStack.equals(next)) {
493                    StackValue.onStack(typeOnStack).put(next, iv);
494                }
495                iv.store(shift, next);
496            }
497        }
498    
499        //TODO: check annotation on class - it's package part
500        //TODO: check it's external module
501        //TODO?: assert method exists in facade?
502        public String changeOwnerForExternalPackage(String type, int opcode) {
503            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
504                return type;
505            }
506    
507            int i = type.indexOf('-');
508            if (i >= 0) {
509                return type.substring(0, i);
510            }
511            return type;
512        }
513    
514    
515        public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
516            if (originalException instanceof InlineException) {
517                return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
518            } else {
519                return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
520                                           InlineCodegen.getNodeText(node), originalException);
521            }
522        }
523    }