001    /*
002     * Copyright 2010-2014 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.util.ArrayUtil;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.jet.codegen.ClosureCodegen;
023    import org.jetbrains.jet.codegen.StackValue;
024    import org.jetbrains.jet.codegen.state.JetTypeMapper;
025    import org.jetbrains.org.objectweb.asm.Label;
026    import org.jetbrains.org.objectweb.asm.MethodVisitor;
027    import org.jetbrains.org.objectweb.asm.Opcodes;
028    import org.jetbrains.org.objectweb.asm.Type;
029    import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
030    import org.jetbrains.org.objectweb.asm.commons.Method;
031    import org.jetbrains.org.objectweb.asm.commons.RemappingMethodAdapter;
032    import org.jetbrains.org.objectweb.asm.tree.*;
033    import org.jetbrains.org.objectweb.asm.tree.analysis.*;
034    
035    import java.util.*;
036    
037    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;
038    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isLambdaConstructorCall;
039    
040    public class MethodInliner {
041    
042        private final MethodNode node;
043    
044        private final Parameters parameters;
045    
046        private final InliningContext inliningContext;
047    
048        private final FieldRemapper nodeRemapper;
049    
050        private final boolean isSameModule;
051    
052        private final String errorPrefix;
053    
054        private final JetTypeMapper typeMapper;
055    
056        private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
057    
058        //keeps order
059        private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
060        //current state
061        private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
062    
063        private final InlineResult result;
064    
065        /*
066         *
067         * @param node
068         * @param parameters
069         * @param inliningContext
070         * @param lambdaType - in case on lambda 'invoke' inlining
071         */
072        public MethodInliner(
073                @NotNull MethodNode node,
074                @NotNull Parameters parameters,
075                @NotNull InliningContext parent,
076                @NotNull FieldRemapper nodeRemapper,
077                boolean isSameModule,
078                @NotNull String errorPrefix
079        ) {
080            this.node = node;
081            this.parameters = parameters;
082            this.inliningContext = parent;
083            this.nodeRemapper = nodeRemapper;
084            this.isSameModule = isSameModule;
085            this.errorPrefix = errorPrefix;
086            this.typeMapper = parent.state.getTypeMapper();
087            this.result = InlineResult.create();
088        }
089    
090    
091        public InlineResult doInline(MethodVisitor adapter, LocalVarRemapper remapper) {
092            return doInline(adapter, remapper, true);
093        }
094    
095        public InlineResult doInline(
096                MethodVisitor adapter,
097                LocalVarRemapper remapper,
098                boolean remapReturn
099        ) {
100            //analyze body
101            MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);
102    
103            transformedNode = doInline(transformedNode);
104            removeClosureAssertions(transformedNode);
105            transformedNode.instructions.resetLabels();
106    
107            Label end = new Label();
108            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn, nodeRemapper);
109            try {
110                transformedNode.accept(visitor);
111            }
112            catch (Exception e) {
113                throw wrapException(e, transformedNode, "couldn't inline method call");
114            }
115    
116            visitor.visitLabel(end);
117    
118            return result;
119        }
120    
121        private MethodNode doInline(MethodNode node) {
122    
123            final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
124    
125            MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
126    
127            final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
128    
129            RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
130                                                                                       new TypeRemapper(currentTypeMapping));
131    
132            InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
133    
134                private ConstructorInvocation invocation;
135                @Override
136                public void anew(Type type) {
137                    if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
138                        invocation = iterator.next();
139    
140                        if (invocation.shouldRegenerate()) {
141                            //TODO: need poping of type but what to do with local funs???
142                            Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
143                            currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
144                            AnonymousObjectTransformer transformer =
145                                    new AnonymousObjectTransformer(invocation.getOwnerInternalName(),
146                                                                   inliningContext
147                                                                           .subInlineWithClassRegeneration(
148                                                                                   inliningContext.nameGenerator,
149                                                                                   currentTypeMapping,
150                                                                                   invocation),
151                                                                   isSameModule, newLambdaType
152                                    );
153    
154                            InlineResult transformResult = transformer.doTransform(invocation, nodeRemapper);
155                            result.addAllClassesToRemove(transformResult);
156    
157                            if (inliningContext.isInliningLambda) {
158                                //this class is transformed and original not used so we should remove original one after inlining
159                                result.addClassToRemove(invocation.getOwnerInternalName());
160                            }
161                        }
162                    }
163    
164                    //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
165                    super.anew(type);
166                }
167    
168                @Override
169                public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
170                    if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
171                        assert !currentInvokes.isEmpty();
172                        InvokeCall invokeCall = currentInvokes.remove();
173                        LambdaInfo info = invokeCall.lambdaInfo;
174    
175                        if (info == null) {
176                            //noninlinable lambda
177                            super.visitMethodInsn(opcode, owner, name, desc, itf);
178                            return;
179                        }
180    
181                        int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
182                        putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
183    
184                        Parameters lambdaParameters = info.addAllParameters();
185    
186                        InlinedLambdaRemapper newCapturedRemapper =
187                                new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), nodeRemapper, lambdaParameters);
188    
189                        setInlining(true);
190                        MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
191                                                                  inliningContext.subInlineLambda(info),
192                                                                  newCapturedRemapper, true /*cause all calls in same module as lambda*/,
193                                                                  "Lambda inlining " + info.getLambdaClassType().getInternalName());
194    
195                        LocalVarRemapper remapper = new LocalVarRemapper(lambdaParameters, valueParamShift);
196                        InlineResult lambdaResult = inliner.doInline(this.mv, remapper);//TODO add skipped this and receiver
197                        result.addAllClassesToRemove(lambdaResult);
198    
199                        //return value boxing/unboxing
200                        Method bridge =
201                                typeMapper.mapSignature(ClosureCodegen.getErasedInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
202                        Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
203                        StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
204                        setInlining(false);
205                    }
206                    else if (isLambdaConstructorCall(owner, name)) { //TODO add method
207                        assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
208                        if (invocation.shouldRegenerate()) {
209                            //put additional captured parameters on stack
210                            for (CapturedParamInfo capturedParamInfo : invocation.getAllRecapturedParameters()) {
211                                visitFieldInsn(Opcodes.GETSTATIC, capturedParamInfo.getContainingLambdaName(), "$$$" + capturedParamInfo.getOriginalFieldName(), capturedParamInfo.getType().getDescriptor());
212                            }
213                            super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor(), itf);
214                            invocation = null;
215                        } else {
216                            super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
217                        }
218                    }
219                    else {
220                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
221                    }
222                }
223    
224            };
225    
226            node.accept(inliner);
227    
228            return resultNode;
229        }
230    
231        @NotNull
232        public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
233            assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
234            FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
235            CapturedParamInfo field = fieldRemapper.findField(fin);
236            if (field == null) {
237                throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
238            }
239            return field;
240        }
241    
242        @NotNull
243        public MethodNode prepareNode(@NotNull MethodNode node) {
244            final int capturedParamsSize = parameters.getCaptured().size();
245            final int realParametersSize = parameters.getReal().size();
246            Type[] types = Type.getArgumentTypes(node.desc);
247            Type returnType = Type.getReturnType(node.desc);
248    
249            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
250            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
251    
252            node.instructions.resetLabels();
253            MethodNode transformedNode = new MethodNode(InlineCodegenUtil.API, node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
254    
255                private final boolean isInliningLambda = nodeRemapper.isInsideInliningLambda();
256    
257                private int getNewIndex(int var) {
258                    return var + (var < realParametersSize ? 0 : capturedParamsSize);
259                }
260    
261                @Override
262                public void visitVarInsn(int opcode, int var) {
263                    super.visitVarInsn(opcode, getNewIndex(var));
264                }
265    
266                @Override
267                public void visitIincInsn(int var, int increment) {
268                    super.visitIincInsn(getNewIndex(var), increment);
269                }
270    
271                @Override
272                public void visitMaxs(int maxStack, int maxLocals) {
273                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
274                }
275    
276                @Override
277                public void visitLineNumber(int line, Label start) {
278                    if(isInliningLambda) {
279                        super.visitLineNumber(line, start);
280                    }
281                }
282    
283                @Override
284                public void visitLocalVariable(
285                        String name, String desc, String signature, Label start, Label end, int index
286                ) {
287                    if (isInliningLambda) {
288                        super.visitLocalVariable(name, desc, signature, start, end, getNewIndex(index));
289                    }
290                }
291            };
292    
293            node.accept(transformedNode);
294    
295            transformCaptured(transformedNode);
296    
297            return transformedNode;
298        }
299    
300        @NotNull
301        protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
302            node = prepareNode(node);
303    
304            Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
305            Frame<SourceValue>[] sources;
306            try {
307                sources = analyzer.analyze("fake", node);
308            }
309            catch (AnalyzerException e) {
310                throw wrapException(e, node, "couldn't inline method call");
311            }
312    
313            AbstractInsnNode cur = node.instructions.getFirst();
314            int index = 0;
315            Set<LabelNode> deadLabels = new HashSet<LabelNode>();
316    
317            while (cur != null) {
318                Frame<SourceValue> frame = sources[index];
319    
320                if (frame != null) {
321                    if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
322                        MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
323                        String owner = methodInsnNode.owner;
324                        String desc = methodInsnNode.desc;
325                        String name = methodInsnNode.name;
326                        //TODO check closure
327                        int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
328                        if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
329                            SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
330    
331                            LambdaInfo lambdaInfo = null;
332                            int varIndex = -1;
333    
334                            if (sourceValue.insns.size() == 1) {
335                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
336    
337                                lambdaInfo = getLambdaIfExists(insnNode);
338                                if (lambdaInfo != null) {
339                                    //remove inlinable access
340                                    node.instructions.remove(insnNode);
341                                }
342                            }
343    
344                            invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
345                        }
346                        else if (isLambdaConstructorCall(owner, name)) {
347                            Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
348                            int paramStart = frame.getStackSize() - paramLength;
349    
350                            for (int i = 0; i < paramLength; i++) {
351                                SourceValue sourceValue = frame.getStack(paramStart + i);
352                                if (sourceValue.insns.size() == 1) {
353                                    AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
354                                    LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
355                                    if (lambdaInfo != null) {
356                                        lambdaMapping.put(i, lambdaInfo);
357                                        node.instructions.remove(insnNode);
358                                    }
359                                }
360                            }
361    
362                            constructorInvocations.add(new ConstructorInvocation(owner, desc, lambdaMapping, isSameModule, inliningContext.classRegeneration));
363                        }
364                    }
365                }
366    
367                AbstractInsnNode prevNode = cur;
368                cur = cur.getNext();
369                index++;
370    
371                //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
372                if (frame == null) {
373                    //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
374                    if (prevNode.getType() == AbstractInsnNode.LABEL) {
375                        deadLabels.add((LabelNode) prevNode);
376                    } else {
377                        node.instructions.remove(prevNode);
378                    }
379                }
380            }
381    
382            //clean dead try/catch blocks
383            List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
384            for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
385                TryCatchBlockNode block = iterator.next();
386                if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
387                    iterator.remove();
388                }
389            }
390    
391            return node;
392        }
393    
394        public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
395            if (insnNode.getOpcode() == Opcodes.ALOAD) {
396                int varIndex = ((VarInsnNode) insnNode).var;
397                if (varIndex < parameters.totalSize()) {
398                    return parameters.get(varIndex).getLambda();
399                }
400            }
401            else if (insnNode instanceof FieldInsnNode) {
402                FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
403                if (fieldInsnNode.name.startsWith("$$$")) {
404                    return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
405                }
406            }
407    
408            return null;
409        }
410    
411        private static void removeClosureAssertions(MethodNode node) {
412            AbstractInsnNode cur = node.instructions.getFirst();
413            while (cur != null && cur.getNext() != null) {
414                AbstractInsnNode next = cur.getNext();
415                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
416                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
417                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
418                        AbstractInsnNode prev = cur.getPrevious();
419    
420                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
421                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
422    
423                        node.instructions.remove(prev);
424                        node.instructions.remove(cur);
425                        cur = next.getNext();
426                        node.instructions.remove(next);
427                        next = cur;
428                    }
429                }
430                cur = next;
431            }
432        }
433    
434        private void transformCaptured(@NotNull MethodNode node) {
435            if (nodeRemapper.isRoot()) {
436                return;
437            }
438    
439            //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
440            //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
441            AbstractInsnNode cur = node.instructions.getFirst();
442            while (cur != null) {
443                if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
444                    if (((VarInsnNode) cur).var == 0) {
445                        List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
446                        AbstractInsnNode insnNode = nodeRemapper.foldFieldAccessChainIfNeeded(accessChain, node);
447                        if (insnNode != null) {
448                            cur = insnNode;
449                        }
450                    }
451                }
452                cur = cur.getNext();
453            }
454        }
455    
456        @NotNull
457        public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
458            List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
459            fieldAccessChain.add(aload0);
460            AbstractInsnNode next = aload0.getNext();
461            while (next != null && next instanceof FieldInsnNode || next instanceof LabelNode) {
462                if (next instanceof LabelNode) {
463                    next = next.getNext();
464                    continue; //it will be delete on transformation
465                }
466                fieldAccessChain.add(next);
467                if ("this$0".equals(((FieldInsnNode) next).name)) {
468                    next = next.getNext();
469                }
470                else {
471                    break;
472                }
473            }
474    
475            return fieldAccessChain;
476        }
477    
478        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
479            Type[] actualParams = Type.getArgumentTypes(descriptor);
480            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
481    
482            int size = 0;
483            for (Type next : directOrder) {
484                size += next.getSize();
485            }
486    
487            shift += size;
488            int index = directOrder.size();
489    
490            for (Type next : Lists.reverse(directOrder)) {
491                shift -= next.getSize();
492                Type typeOnStack = actualParams[--index];
493                if (!typeOnStack.equals(next)) {
494                    StackValue.onStack(typeOnStack).put(next, iv);
495                }
496                iv.store(shift, next);
497            }
498        }
499    
500        //TODO: check annotation on class - it's package part
501        //TODO: check it's external module
502        //TODO?: assert method exists in facade?
503        public String changeOwnerForExternalPackage(String type, int opcode) {
504            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
505                return type;
506            }
507    
508            int i = type.indexOf('-');
509            if (i >= 0) {
510                return type.substring(0, i);
511            }
512            return type;
513        }
514    
515    
516        public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
517            if (originalException instanceof InlineException) {
518                return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
519            } else {
520                return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
521                                           InlineCodegen.getNodeText(node), originalException);
522            }
523        }
524    }