001    package org.jetbrains.jet.codegen.inline;
002    
003    import com.google.common.collect.Lists;
004    import com.intellij.util.ArrayUtil;
005    import org.jetbrains.annotations.NotNull;
006    import org.jetbrains.asm4.Label;
007    import org.jetbrains.asm4.MethodVisitor;
008    import org.jetbrains.asm4.Opcodes;
009    import org.jetbrains.asm4.Type;
010    import org.jetbrains.asm4.commons.InstructionAdapter;
011    import org.jetbrains.asm4.commons.Method;
012    import org.jetbrains.asm4.commons.RemappingMethodAdapter;
013    import org.jetbrains.asm4.tree.*;
014    import org.jetbrains.asm4.tree.analysis.*;
015    import org.jetbrains.jet.codegen.ClosureCodegen;
016    import org.jetbrains.jet.codegen.StackValue;
017    import org.jetbrains.jet.codegen.state.JetTypeMapper;
018    
019    import java.util.*;
020    
021    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;
022    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isLambdaConstructorCall;
023    
024    public class MethodInliner {
025    
026        private final MethodNode node;
027    
028        private final Parameters parameters;
029    
030        private final InliningContext inliningContext;
031    
032        private final FieldRemapper nodeRemapper;
033    
034        private final boolean isSameModule;
035    
036        private final String errorPrefix;
037    
038        private final JetTypeMapper typeMapper;
039    
040        private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
041    
042        //keeps order
043        private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
044        //current state
045        private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
046    
047        private final InlineResult result;
048    
/**
 * Creates an inliner for a single method body.
 *
 * @param node         method body to inline
 * @param parameters   parameters (real and captured) of the inlined method
 * @param parent       inlining context this inliner works in; its generation state supplies the type mapper
 * @param nodeRemapper remapper used to resolve captured-field accesses inside {@code node}
 * @param isSameModule module flag consulted by {@code changeOwnerForExternalPackage} and forwarded
 *                     to lambda regeneration
 * @param errorPrefix  text prepended to the message of exceptions created by {@code wrapException}
 */
public MethodInliner(
        @NotNull MethodNode node,
        @NotNull Parameters parameters,
        @NotNull InliningContext parent,
        @NotNull FieldRemapper nodeRemapper,
        boolean isSameModule,
        @NotNull String errorPrefix
) {
    // What is inlined.
    this.node = node;
    this.parameters = parameters;
    this.nodeRemapper = nodeRemapper;

    // Where it is inlined.
    this.inliningContext = parent;
    this.typeMapper = parent.state.getTypeMapper();
    this.isSameModule = isSameModule;

    // Diagnostics and accumulated outcome.
    this.errorPrefix = errorPrefix;
    this.result = InlineResult.create();
}
073    
074    
075        public InlineResult doInline(MethodVisitor adapter, VarRemapper.ParamRemapper remapper, FieldRemapper capturedRemapper) {
076            return doInline(adapter, remapper, capturedRemapper, true);
077        }
078    
079        public InlineResult doInline(
080                MethodVisitor adapter,
081                VarRemapper.ParamRemapper remapper,
082                FieldRemapper capturedRemapper, boolean remapReturn
083        ) {
084            //analyze body
085            MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);
086    
087            transformedNode = doInline(transformedNode, capturedRemapper);
088            removeClosureAssertions(transformedNode);
089            transformedNode.instructions.resetLabels();
090    
091            Label end = new Label();
092            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn, nodeRemapper);
093            try {
094                transformedNode.accept(visitor);
095            }
096            catch (Exception e) {
097                throw wrapException(e, transformedNode, "couldn't inline method call");
098            }
099    
100            visitor.visitLabel(end);
101    
102            return result;
103        }
104    
105        private MethodNode doInline(MethodNode node, final FieldRemapper capturedRemapper) {
106    
107            final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
108    
109            MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
110    
111            final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
112    
113            RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
114                                                                                       new TypeRemapper(currentTypeMapping));
115    
116            InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
117    
118                private ConstructorInvocation invocation;
119                @Override
120                public void anew(Type type) {
121                    if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
122                        invocation = iterator.next();
123    
124                        if (invocation.shouldRegenerate()) {
125                            //TODO: need poping of type but what to do with local funs???
126                            Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
127                            currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
128                            LambdaTransformer transformer = new LambdaTransformer(invocation.getOwnerInternalName(),
129                                                                                  inliningContext.subInline(inliningContext.nameGenerator, currentTypeMapping).classRegeneration(),
130                                                                                  isSameModule, newLambdaType);
131    
132                            InlineResult transformResult = transformer.doTransform(invocation, capturedRemapper);
133                            result.addAllClassesToRemove(transformResult);
134    
135                            if (inliningContext.isInliningLambda) {
136                                //this class is transformed and original not used so we should remove original one after inlining
137                                result.addClassToRemove(invocation.getOwnerInternalName());
138                            }
139                        }
140                    }
141    
142                    //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
143                    super.anew(type);
144                }
145    
146                @Override
147                public void visitMethodInsn(int opcode, String owner, String name, String desc) {
148                    if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
149                        assert !currentInvokes.isEmpty();
150                        InvokeCall invokeCall = currentInvokes.remove();
151                        LambdaInfo info = invokeCall.lambdaInfo;
152    
153                        if (info == null) {
154                            //noninlinable lambda
155                            super.visitMethodInsn(opcode, owner, name, desc);
156                            return;
157                        }
158    
159                        int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
160                        putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
161    
162                        Parameters lambdaParameters = info.addAllParameters(capturedRemapper);
163    
164                        InlinedLambdaRemapper newCapturedRemapper =
165                                new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), capturedRemapper, lambdaParameters);
166    
167                        FieldRemapper fieldRemapper =
168                                new FieldRemapper(info.getLambdaClassType().getInternalName(), capturedRemapper, lambdaParameters);
169    
170                        setInlining(true);
171                        MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
172                                                                  inliningContext.subInlineLambda(info),
173                                                                  fieldRemapper, true /*cause all calls in same module as lambda*/,
174                                                                  "Lambda inlining " + info.getLambdaClassType().getInternalName());
175    
176                        VarRemapper.ParamRemapper remapper = new VarRemapper.ParamRemapper(lambdaParameters, valueParamShift);
177                        InlineResult lambdaResult = inliner.doInline(this.mv, remapper, newCapturedRemapper);//TODO add skipped this and receiver
178                        result.addAllClassesToRemove(lambdaResult);
179    
180                        //return value boxing/unboxing
181                        Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
182                        Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
183                        StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
184                        setInlining(false);
185                    }
186                    else if (isLambdaConstructorCall(owner, name)) { //TODO add method
187                        assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
188                        if (invocation.shouldRegenerate()) {
189                            //put additional captured parameters on stack
190                            for (CapturedParamInfo capturedParamInfo : invocation.getAllRecapturedParameters()) {
191                                visitFieldInsn(Opcodes.GETSTATIC, capturedParamInfo.getContainingLambdaName(), "$$$" + capturedParamInfo.getFieldName(), capturedParamInfo.getType().getDescriptor());
192                            }
193                            super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor());
194                            invocation = null;
195                        } else {
196                            super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
197                        }
198                    }
199                    else {
200                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
201                    }
202                }
203    
204            };
205    
206            node.accept(inliner);
207    
208            return resultNode;
209        }
210    
211        @NotNull
212        public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
213            assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
214            FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
215            CapturedParamInfo field = fieldRemapper.findField(fin);
216            if (field == null) {
217                throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
218            }
219            return field;
220        }
221    
222        @NotNull
223        public MethodNode prepareNode(@NotNull MethodNode node) {
224            final int capturedParamsSize = parameters.getCaptured().size();
225            final int realParametersSize = parameters.getReal().size();
226            Type[] types = Type.getArgumentTypes(node.desc);
227            Type returnType = Type.getReturnType(node.desc);
228    
229            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
230            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
231    
232            node.instructions.resetLabels();
233            MethodNode transformedNode = new MethodNode(node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
234    
235                @Override
236                public void visitVarInsn(int opcode, int var) {
237                    int newIndex;
238                    if (var < realParametersSize) {
239                        newIndex = var;
240                    } else {
241                        newIndex = var + capturedParamsSize;
242                    }
243                    super.visitVarInsn(opcode, newIndex);
244                }
245    
246                @Override
247                public void visitIincInsn(int var, int increment) {
248                    int newIndex;
249                    if (var < realParametersSize) {
250                        newIndex = var;
251                    } else {
252                        newIndex = var + capturedParamsSize;
253                    }
254                    super.visitIincInsn(newIndex, increment);
255                }
256    
257                @Override
258                public void visitMaxs(int maxStack, int maxLocals) {
259                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
260                }
261            };
262    
263            node.accept(transformedNode);
264    
265            transformCaptured(transformedNode);
266    
267            return transformedNode;
268        }
269    
270        @NotNull
271        protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
272            node = prepareNode(node);
273    
274            Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
275            Frame<SourceValue>[] sources;
276            try {
277                sources = analyzer.analyze("fake", node);
278            }
279            catch (AnalyzerException e) {
280                throw wrapException(e, node, "couldn't inline method call");
281            }
282    
283            AbstractInsnNode cur = node.instructions.getFirst();
284            int index = 0;
285            Set<LabelNode> deadLabels = new HashSet<LabelNode>();
286    
287            while (cur != null) {
288                Frame<SourceValue> frame = sources[index];
289    
290                if (frame != null) {
291                    if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
292                        MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
293                        String owner = methodInsnNode.owner;
294                        String desc = methodInsnNode.desc;
295                        String name = methodInsnNode.name;
296                        //TODO check closure
297                        int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
298                        if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
299                            SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
300    
301                            LambdaInfo lambdaInfo = null;
302                            int varIndex = -1;
303    
304                            if (sourceValue.insns.size() == 1) {
305                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
306    
307                                lambdaInfo = getLambdaIfExists(insnNode);
308                                if (lambdaInfo != null) {
309                                    //remove inlinable access
310                                    node.instructions.remove(insnNode);
311                                }
312                            }
313    
314                            invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
315                        }
316                        else if (isLambdaConstructorCall(owner, name)) {
317                            Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
318                            int paramStart = frame.getStackSize() - paramLength;
319    
320                            for (int i = 0; i < paramLength; i++) {
321                                SourceValue sourceValue = frame.getStack(paramStart + i);
322                                if (sourceValue.insns.size() == 1) {
323                                    AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
324                                    LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
325                                    if (lambdaInfo != null) {
326                                        lambdaMapping.put(i, lambdaInfo);
327                                        node.instructions.remove(insnNode);
328                                    }
329                                }
330                            }
331    
332                            constructorInvocations.add(new ConstructorInvocation(owner, lambdaMapping, isSameModule, inliningContext.classRegeneration));
333                        }
334                    }
335                }
336    
337                AbstractInsnNode prevNode = cur;
338                cur = cur.getNext();
339                index++;
340    
341                //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
342                if (frame == null) {
343                    //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
344                    if (prevNode.getType() == AbstractInsnNode.LABEL) {
345                        deadLabels.add((LabelNode) prevNode);
346                    } else {
347                        node.instructions.remove(prevNode);
348                    }
349                }
350            }
351    
352            //clean dead try/catch blocks
353            List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
354            for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
355                TryCatchBlockNode block = iterator.next();
356                if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
357                    iterator.remove();
358                }
359            }
360    
361            return node;
362        }
363    
364        public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
365            if (insnNode.getOpcode() == Opcodes.ALOAD) {
366                int varIndex = ((VarInsnNode) insnNode).var;
367                if (varIndex < parameters.totalSize()) {
368                    return parameters.get(varIndex).getLambda();
369                }
370            }
371            else if (insnNode instanceof FieldInsnNode) {
372                FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
373                if (fieldInsnNode.name.startsWith("$$$")) {
374                    return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
375                }
376            }
377    
378            return null;
379        }
380    
381        private static void removeClosureAssertions(MethodNode node) {
382            AbstractInsnNode cur = node.instructions.getFirst();
383            while (cur != null && cur.getNext() != null) {
384                AbstractInsnNode next = cur.getNext();
385                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
386                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
387                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
388                        AbstractInsnNode prev = cur.getPrevious();
389    
390                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
391                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
392    
393                        node.instructions.remove(prev);
394                        node.instructions.remove(cur);
395                        cur = next.getNext();
396                        node.instructions.remove(next);
397                        next = cur;
398                    }
399                }
400                cur = next;
401            }
402        }
403    
404        private void transformCaptured(@NotNull MethodNode node) {
405            if (nodeRemapper.isRoot()) {
406                return;
407            }
408    
409            //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
410            //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
411            AbstractInsnNode cur = node.instructions.getFirst();
412            while (cur != null) {
413                if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
414                    if (((VarInsnNode) cur).var == 0) {
415                        List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
416                        AbstractInsnNode insnNode = nodeRemapper.transformIfNeeded(accessChain, node);
417                        if (insnNode != null) {
418                            cur = insnNode;
419                        }
420                    }
421                }
422                cur = cur.getNext();
423            }
424        }
425    
426        @NotNull
427        public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
428            List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
429            fieldAccessChain.add(aload0);
430            AbstractInsnNode next = aload0.getNext();
431            while (next != null && next instanceof FieldInsnNode || next instanceof LabelNode) {
432                if (next instanceof LabelNode) {
433                    next = next.getNext();
434                    continue; //it will be delete on transformation
435                }
436                fieldAccessChain.add(next);
437                if ("this$0".equals(((FieldInsnNode) next).name)) {
438                    next = next.getNext();
439                }
440                else {
441                    break;
442                }
443            }
444    
445            return fieldAccessChain;
446        }
447    
448        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
449            Type[] actualParams = Type.getArgumentTypes(descriptor);
450            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
451    
452            int size = 0;
453            for (Type next : directOrder) {
454                size += next.getSize();
455            }
456    
457            shift += size;
458            int index = directOrder.size();
459    
460            for (Type next : Lists.reverse(directOrder)) {
461                shift -= next.getSize();
462                Type typeOnStack = actualParams[--index];
463                if (!typeOnStack.equals(next)) {
464                    StackValue.onStack(typeOnStack).put(next, iv);
465                }
466                iv.store(shift, next);
467            }
468        }
469    
470        //TODO: check annotation on class - it's package part
471        //TODO: check it's external module
472        //TODO?: assert method exists in facade?
473        public String changeOwnerForExternalPackage(String type, int opcode) {
474            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
475                return type;
476            }
477    
478            int i = type.indexOf('-');
479            if (i >= 0) {
480                return type.substring(0, i);
481            }
482            return type;
483        }
484    
485    
486        public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
487            if (originalException instanceof InlineException) {
488                return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
489            } else {
490                return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
491                                           InlineCodegen.getNodeText(node), originalException);
492            }
493        }
494    }