001    package org.jetbrains.jet.codegen.inline;
002    
003    import com.google.common.collect.Lists;
004    import com.intellij.util.ArrayUtil;
005    import org.jetbrains.annotations.NotNull;
006    import org.jetbrains.asm4.Label;
007    import org.jetbrains.asm4.MethodVisitor;
008    import org.jetbrains.asm4.Opcodes;
009    import org.jetbrains.asm4.Type;
010    import org.jetbrains.asm4.commons.InstructionAdapter;
011    import org.jetbrains.asm4.commons.Method;
012    import org.jetbrains.asm4.commons.RemappingMethodAdapter;
013    import org.jetbrains.asm4.tree.*;
014    import org.jetbrains.asm4.tree.analysis.*;
015    import org.jetbrains.jet.codegen.ClosureCodegen;
016    import org.jetbrains.jet.codegen.StackValue;
017    import org.jetbrains.jet.codegen.state.JetTypeMapper;
018    
019    import java.util.*;
020    
021    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;
022    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isLambdaConstructorCall;
023    
024    public class MethodInliner {
025    
    // Method whose body is being inlined
    private final MethodNode node;

    // Real and captured parameters of the inlined method
    private final Parameters parameters;

    // Context of the current inlining; supplies the name generator, generation state and sub-contexts
    private final InliningContext inliningContext;

    // Resolves and folds accesses to captured fields ("$$$"-prefixed templates)
    private final FieldRemapper nodeRemapper;

    // When true, method call owners are never rewritten by changeOwnerForExternalPackage
    private final boolean isSameModule;

    // Prepended to the message of every exception produced by wrapException
    private final String errorPrefix;

    // Taken from the parent context's generation state
    private final JetTypeMapper typeMapper;

    // Lambda 'invoke' calls in instruction order: recorded while marking, consumed in the same order while inlining
    private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();

    //keeps order
    private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
    //current state
    // Maps the internal name of an original lambda class to the internal name of its regenerated copy
    private final Map<String, String> currentTypeMapping = new HashMap<String, String>();

    // Accumulates the outcome of inlining (e.g. classes to remove), returned from doInline
    private final InlineResult result;

    /**
     * Creates an inliner for a single method body.
     *
     * @param node          method to inline
     * @param parameters    parameters of the inlined method (real and captured)
     * @param parent        inlining context; also the source of the type mapper
     * @param nodeRemapper  remapper used for captured field accesses
     * @param isSameModule  whether the inlined code belongs to the same module as the call site
     * @param errorPrefix   prefix for the messages of exceptions raised during inlining
     */
    public MethodInliner(
            @NotNull MethodNode node,
            @NotNull Parameters parameters,
            @NotNull InliningContext parent,
            @NotNull FieldRemapper nodeRemapper,
            boolean isSameModule,
            @NotNull String errorPrefix
    ) {
        this.node = node;
        this.parameters = parameters;
        this.inliningContext = parent;
        this.nodeRemapper = nodeRemapper;
        this.isSameModule = isSameModule;
        this.errorPrefix = errorPrefix;
        this.typeMapper = parent.state.getTypeMapper();
        this.result = InlineResult.create();
    }
073    
074    
075        public InlineResult doInline(MethodVisitor adapter, LocalVarRemapper remapper) {
076            return doInline(adapter, remapper, true);
077        }
078    
079        public InlineResult doInline(
080                MethodVisitor adapter,
081                LocalVarRemapper remapper,
082                boolean remapReturn
083        ) {
084            //analyze body
085            MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);
086    
087            transformedNode = doInline(transformedNode);
088            removeClosureAssertions(transformedNode);
089            transformedNode.instructions.resetLabels();
090    
091            Label end = new Label();
092            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn, nodeRemapper);
093            try {
094                transformedNode.accept(visitor);
095            }
096            catch (Exception e) {
097                throw wrapException(e, transformedNode, "couldn't inline method call");
098            }
099    
100            visitor.visitLabel(end);
101    
102            return result;
103        }
104    
    /**
     * Replays {@code node} into a fresh {@link MethodNode}, inlining lambda 'invoke' calls and
     * redirecting lambda constructor calls to regenerated classes where required.
     * <p>
     * Relies on {@code invokeCalls} and {@code constructorInvocations} having been filled, in
     * instruction order, by {@code markPlacesForInlineAndRemoveInlinable} for the same node.
     *
     * @param node prepared method body to process
     * @return a new node containing the transformed instructions
     */
    private MethodNode doInline(MethodNode node) {

        // working copy: entries are consumed one per visited 'invoke' call
        final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);

        MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);

        // advanced once per NEW of a lambda class
        final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();

        // type names registered in currentTypeMapping are remapped on the way into resultNode
        RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
                                                                                   new TypeRemapper(currentTypeMapping));

        InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {

            // invocation picked up by the latest lambda NEW; read by the following <init> call
            private ConstructorInvocation invocation;
            @Override
            public void anew(Type type) {
                if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
                    invocation = iterator.next();

                    if (invocation.shouldRegenerate()) {
                        //TODO: need poping of type but what to do with local funs???
                        Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
                        currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
                        LambdaTransformer transformer = new LambdaTransformer(invocation.getOwnerInternalName(),
                                                                              inliningContext.subInline(inliningContext.nameGenerator, currentTypeMapping).classRegeneration(),
                                                                              isSameModule, newLambdaType);

                        InlineResult transformResult = transformer.doTransform(invocation, nodeRemapper);
                        result.addAllClassesToRemove(transformResult);

                        if (inliningContext.isInliningLambda) {
                            //this class is transformed and original not used so we should remove original one after inlining
                            result.addClassToRemove(invocation.getOwnerInternalName());
                        }
                    }
                }

                //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
                super.anew(type);
            }

            @Override
            public void visitMethodInsn(int opcode, String owner, String name, String desc) {
                // Case 1: 'invoke' on a lambda - inline its body when lambda info was recorded
                if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
                    assert !currentInvokes.isEmpty();
                    InvokeCall invokeCall = currentInvokes.remove();
                    LambdaInfo info = invokeCall.lambdaInfo;

                    if (info == null) {
                        //noninlinable lambda
                        super.visitMethodInsn(opcode, owner, name, desc);
                        return;
                    }

                    // arguments currently on the stack are stored into fresh locals starting at this index
                    int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
                    putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);

                    Parameters lambdaParameters = info.addAllParameters();

                    InlinedLambdaRemapper newCapturedRemapper =
                            new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), nodeRemapper, lambdaParameters);

                    setInlining(true);
                    // nested inliner processes the lambda body and writes it to the same output visitor
                    MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
                                                              inliningContext.subInlineLambda(info),
                                                              newCapturedRemapper, true /*cause all calls in same module as lambda*/,
                                                              "Lambda inlining " + info.getLambdaClassType().getInternalName());

                    LocalVarRemapper remapper = new LocalVarRemapper(lambdaParameters, valueParamShift);
                    InlineResult lambdaResult = inliner.doInline(this.mv, remapper);//TODO add skipped this and receiver
                    result.addAllClassesToRemove(lambdaResult);

                    //return value boxing/unboxing
                    Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
                    Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
                    StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
                    setInlining(false);
                }
                // Case 2: lambda constructor call - retarget to the regenerated class when needed
                else if (isLambdaConstructorCall(owner, name)) { //TODO add method
                    assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
                    if (invocation.shouldRegenerate()) {
                        //put additional captured parameters on stack
                        for (CapturedParamInfo capturedParamInfo : invocation.getAllRecapturedParameters()) {
                            visitFieldInsn(Opcodes.GETSTATIC, capturedParamInfo.getContainingLambdaName(), "$$$" + capturedParamInfo.getOriginalFieldName(), capturedParamInfo.getType().getDescriptor());
                        }
                        super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor());
                        // NOTE: reset only on this branch; the non-regenerated branch leaves it set until the next lambda NEW
                        invocation = null;
                    } else {
                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
                    }
                }
                // Case 3: any other call - only the owner may be rewritten
                else {
                    super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
                }
            }

        };

        node.accept(inliner);

        return resultNode;
    }
207    
208        @NotNull
209        public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
210            assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
211            FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
212            CapturedParamInfo field = fieldRemapper.findField(fin);
213            if (field == null) {
214                throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
215            }
216            return field;
217        }
218    
219        @NotNull
220        public MethodNode prepareNode(@NotNull MethodNode node) {
221            final int capturedParamsSize = parameters.getCaptured().size();
222            final int realParametersSize = parameters.getReal().size();
223            Type[] types = Type.getArgumentTypes(node.desc);
224            Type returnType = Type.getReturnType(node.desc);
225    
226            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
227            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
228    
229            node.instructions.resetLabels();
230            MethodNode transformedNode = new MethodNode(node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
231    
232                private final boolean isInliningLambda = nodeRemapper.isInsideInliningLambda();
233    
234                private int getNewIndex(int var) {
235                    return var + (var < realParametersSize ? 0 : capturedParamsSize);
236                }
237    
238                @Override
239                public void visitVarInsn(int opcode, int var) {
240                    super.visitVarInsn(opcode, getNewIndex(var));
241                }
242    
243                @Override
244                public void visitIincInsn(int var, int increment) {
245                    super.visitIincInsn(getNewIndex(var), increment);
246                }
247    
248                @Override
249                public void visitMaxs(int maxStack, int maxLocals) {
250                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
251                }
252    
253                @Override
254                public void visitLineNumber(int line, Label start) {
255                    if(isInliningLambda) {
256                        super.visitLineNumber(line, start);
257                    }
258                }
259    
260                @Override
261                public void visitLocalVariable(
262                        String name, String desc, String signature, Label start, Label end, int index
263                ) {
264                    if (isInliningLambda) {
265                        super.visitLocalVariable(name, desc, signature, start, end, getNewIndex(index));
266                    }
267                }
268            };
269    
270            node.accept(transformedNode);
271    
272            transformCaptured(transformedNode);
273    
274            return transformedNode;
275        }
276    
    /**
     * Analysis pass: prepares {@code node}, then walks its instructions together with the frames
     * computed by a {@link SourceInterpreter} analysis and
     * <ul>
     * <li>records an {@link InvokeCall} for every lambda 'invoke' call (with the lambda info when the
     * receiver comes from a single instruction that resolves to a lambda);</li>
     * <li>records a {@link ConstructorInvocation} for every lambda constructor call, mapping argument
     * positions to the lambdas passed there;</li>
     * <li>removes the instructions that loaded those lambdas;</li>
     * <li>removes unreachable non-label instructions and try/catch blocks whose start and end labels
     * are both unreachable.</li>
     * </ul>
     *
     * @param node method to analyze
     * @return the prepared and cleaned copy of the method
     */
    @NotNull
    protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
        node = prepareNode(node);

        Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
        Frame<SourceValue>[] sources;
        try {
            sources = analyzer.analyze("fake", node);
        }
        catch (AnalyzerException e) {
            throw wrapException(e, node, "couldn't inline method call");
        }

        // 'index' counts visited instructions and selects the frame that belongs to 'cur'
        AbstractInsnNode cur = node.instructions.getFirst();
        int index = 0;
        Set<LabelNode> deadLabels = new HashSet<LabelNode>();

        while (cur != null) {
            Frame<SourceValue> frame = sources[index];

            if (frame != null) {
                if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
                    MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
                    String owner = methodInsnNode.owner;
                    String desc = methodInsnNode.desc;
                    String name = methodInsnNode.name;
                    //TODO check closure
                    int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
                    if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
                        // stack slot of the receiver: the deepest of the call's operands
                        SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);

                        LambdaInfo lambdaInfo = null;
                        // never reassigned below: every InvokeCall is created with -1
                        int varIndex = -1;

                        // only a receiver produced by exactly one instruction is considered
                        if (sourceValue.insns.size() == 1) {
                            AbstractInsnNode insnNode = sourceValue.insns.iterator().next();

                            lambdaInfo = getLambdaIfExists(insnNode);
                            if (lambdaInfo != null) {
                                //remove inlinable access
                                node.instructions.remove(insnNode);
                            }
                        }

                        // recorded even when lambdaInfo is null so the list stays aligned with the 'invoke' calls
                        invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
                    }
                    else if (isLambdaConstructorCall(owner, name)) {
                        // key: operand position counted from the receiver (position 0)
                        Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
                        int paramStart = frame.getStackSize() - paramLength;

                        for (int i = 0; i < paramLength; i++) {
                            SourceValue sourceValue = frame.getStack(paramStart + i);
                            if (sourceValue.insns.size() == 1) {
                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
                                LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
                                if (lambdaInfo != null) {
                                    lambdaMapping.put(i, lambdaInfo);
                                    node.instructions.remove(insnNode);
                                }
                            }
                        }

                        constructorInvocations.add(new ConstructorInvocation(owner, lambdaMapping, isSameModule, inliningContext.classRegeneration));
                    }
                }
            }

            // advance before any removal of the current instruction, so the successor is read while it is still linked
            AbstractInsnNode prevNode = cur;
            cur = cur.getNext();
            index++;

            //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
            if (frame == null) {
                //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
                if (prevNode.getType() == AbstractInsnNode.LABEL) {
                    deadLabels.add((LabelNode) prevNode);
                } else {
                    node.instructions.remove(prevNode);
                }
            }
        }

        //clean dead try/catch blocks
        List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
        for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
            TryCatchBlockNode block = iterator.next();
            if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
                iterator.remove();
            }
        }

        return node;
    }
370    
371        public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
372            if (insnNode.getOpcode() == Opcodes.ALOAD) {
373                int varIndex = ((VarInsnNode) insnNode).var;
374                if (varIndex < parameters.totalSize()) {
375                    return parameters.get(varIndex).getLambda();
376                }
377            }
378            else if (insnNode instanceof FieldInsnNode) {
379                FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
380                if (fieldInsnNode.name.startsWith("$$$")) {
381                    return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
382                }
383            }
384    
385            return null;
386        }
387    
388        private static void removeClosureAssertions(MethodNode node) {
389            AbstractInsnNode cur = node.instructions.getFirst();
390            while (cur != null && cur.getNext() != null) {
391                AbstractInsnNode next = cur.getNext();
392                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
393                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
394                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
395                        AbstractInsnNode prev = cur.getPrevious();
396    
397                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
398                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
399    
400                        node.instructions.remove(prev);
401                        node.instructions.remove(cur);
402                        cur = next.getNext();
403                        node.instructions.remove(next);
404                        next = cur;
405                    }
406                }
407                cur = next;
408            }
409        }
410    
411        private void transformCaptured(@NotNull MethodNode node) {
412            if (nodeRemapper.isRoot()) {
413                return;
414            }
415    
416            //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
417            //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
418            AbstractInsnNode cur = node.instructions.getFirst();
419            while (cur != null) {
420                if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
421                    if (((VarInsnNode) cur).var == 0) {
422                        List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
423                        AbstractInsnNode insnNode = nodeRemapper.foldFieldAccessChainIfNeeded(accessChain, node);
424                        if (insnNode != null) {
425                            cur = insnNode;
426                        }
427                    }
428                }
429                cur = cur.getNext();
430            }
431        }
432    
433        @NotNull
434        public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
435            List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
436            fieldAccessChain.add(aload0);
437            AbstractInsnNode next = aload0.getNext();
438            while (next != null && next instanceof FieldInsnNode || next instanceof LabelNode) {
439                if (next instanceof LabelNode) {
440                    next = next.getNext();
441                    continue; //it will be delete on transformation
442                }
443                fieldAccessChain.add(next);
444                if ("this$0".equals(((FieldInsnNode) next).name)) {
445                    next = next.getNext();
446                }
447                else {
448                    break;
449                }
450            }
451    
452            return fieldAccessChain;
453        }
454    
455        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
456            Type[] actualParams = Type.getArgumentTypes(descriptor);
457            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
458    
459            int size = 0;
460            for (Type next : directOrder) {
461                size += next.getSize();
462            }
463    
464            shift += size;
465            int index = directOrder.size();
466    
467            for (Type next : Lists.reverse(directOrder)) {
468                shift -= next.getSize();
469                Type typeOnStack = actualParams[--index];
470                if (!typeOnStack.equals(next)) {
471                    StackValue.onStack(typeOnStack).put(next, iv);
472                }
473                iv.store(shift, next);
474            }
475        }
476    
477        //TODO: check annotation on class - it's package part
478        //TODO: check it's external module
479        //TODO?: assert method exists in facade?
480        public String changeOwnerForExternalPackage(String type, int opcode) {
481            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
482                return type;
483            }
484    
485            int i = type.indexOf('-');
486            if (i >= 0) {
487                return type.substring(0, i);
488            }
489            return type;
490        }
491    
492    
493        public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
494            if (originalException instanceof InlineException) {
495                return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
496            } else {
497                return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
498                                           InlineCodegen.getNodeText(node), originalException);
499            }
500        }
501    }