001    package org.jetbrains.jet.codegen.inline;
002    
003    import com.google.common.collect.Lists;
004    import com.intellij.util.ArrayUtil;
005    import org.jetbrains.annotations.NotNull;
006    import org.jetbrains.annotations.Nullable;
007    import org.jetbrains.asm4.Label;
008    import org.jetbrains.asm4.MethodVisitor;
009    import org.jetbrains.asm4.Opcodes;
010    import org.jetbrains.asm4.Type;
011    import org.jetbrains.asm4.commons.InstructionAdapter;
012    import org.jetbrains.asm4.commons.Method;
013    import org.jetbrains.asm4.commons.RemappingMethodAdapter;
014    import org.jetbrains.asm4.tree.*;
015    import org.jetbrains.asm4.tree.analysis.*;
016    import org.jetbrains.jet.codegen.ClosureCodegen;
017    import org.jetbrains.jet.codegen.StackValue;
018    import org.jetbrains.jet.codegen.state.JetTypeMapper;
019    import org.jetbrains.jet.utils.UtilsPackage;
020    
021    import java.util.*;
022    
023    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.*;
024    
025    public class MethodInliner {
026    
027        private final MethodNode node;
028    
029        private final Parameters parameters;
030    
031        private final InliningContext parent;
032    
033        @Nullable
034        private final Type lambdaType;
035    
036        private final LambdaFieldRemapper lambdaFieldRemapper;
037    
038        private final boolean isSameModule;
039    
040        private final JetTypeMapper typeMapper;
041    
042        private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
043    
044        //keeps order
045        private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
046        //current state
047        private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
048    
049        /*
050         *
051         * @param node
052         * @param parameters
053         * @param parent
054         * @param lambdaType - in case on lambda 'invoke' inlining
055         */
056        public MethodInliner(
057                @NotNull MethodNode node,
058                @NotNull Parameters parameters,
059                @NotNull InliningContext parent,
060                @Nullable Type lambdaType,
061                LambdaFieldRemapper lambdaFieldRemapper,
062                boolean isSameModule
063        ) {
064            this.node = node;
065            this.parameters = parameters;
066            this.parent = parent;
067            this.lambdaType = lambdaType;
068            this.lambdaFieldRemapper = lambdaFieldRemapper;
069            this.isSameModule = isSameModule;
070            this.typeMapper = parent.state.getTypeMapper();
071        }
072    
073    
074        public void doInline(MethodVisitor adapter, VarRemapper.ParamRemapper remapper) {
075            doInline(adapter, remapper, new LambdaFieldRemapper(), true);
076        }
077    
078        public void doInline(
079                MethodVisitor adapter,
080                VarRemapper.ParamRemapper remapper,
081                LambdaFieldRemapper capturedRemapper, boolean remapReturn
082        ) {
083            //analyze body
084            MethodNode transformedNode = node;
085            try {
086                transformedNode = markPlacesForInlineAndRemoveInlinable(transformedNode);
087            }
088            catch (AnalyzerException e) {
089                throw UtilsPackage.rethrow(e);
090            }
091    
092            transformedNode = doInline(transformedNode, capturedRemapper);
093            removeClosureAssertions(transformedNode);
094            transformedNode.instructions.resetLabels();
095    
096            Label end = new Label();
097            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn);
098            transformedNode.accept(visitor);
099            visitor.visitLabel(end);
100    
101        }
102    
103        private MethodNode doInline(MethodNode node, final LambdaFieldRemapper capturedRemapper) {
104    
105            final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
106    
107            MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
108    
109            final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
110    
111            RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
112                                                                                       new TypeRemapper(currentTypeMapping));
113    
114            InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
115    
116                private ConstructorInvocation invocation;
117                @Override
118                public void anew(Type type) {
119                    if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
120                        invocation = iterator.next();
121    
122                        if (invocation.shouldRegenerate()) {
123                            //TODO: need poping of type but what to do with local funs???
124                            Type newLambdaType = Type.getObjectType(parent.nameGenerator.genLambdaClassName());
125                            currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
126                            LambdaTransformer transformer = new LambdaTransformer(invocation.getOwnerInternalName(),
127                                                                                  parent.subInline(parent.nameGenerator, currentTypeMapping),
128                                                                                  isSameModule, newLambdaType);
129    
130                            transformer.doTransform(invocation);
131                        }
132                    }
133    
134                    //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
135                    super.anew(type);
136                }
137    
138                @Override
139                public void visitMethodInsn(int opcode, String owner, String name, String desc) {
140                    if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
141                        assert !currentInvokes.isEmpty();
142                        InvokeCall invokeCall = currentInvokes.remove();
143                        LambdaInfo info = invokeCall.lambdaInfo;
144    
145                        if (info == null) {
146                            //noninlinable lambda
147                            super.visitMethodInsn(opcode, owner, name, desc);
148                            return;
149                        }
150    
151                        int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
152                        putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
153    
154                        Parameters lambdaParameters = info.addAllParameters(capturedRemapper);
155    
156                        setInlining(true);
157                        MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters, parent.subInline(parent.nameGenerator.subGenerator("lambda")), info.getLambdaClassType(),
158                                                                  capturedRemapper, true /*cause all calls in same module as lambda*/);
159    
160                        VarRemapper.ParamRemapper remapper = new VarRemapper.ParamRemapper(lambdaParameters, valueParamShift);
161                        inliner.doInline(this.mv, remapper); //TODO add skipped this and receiver
162    
163                        //return value boxing/unboxing
164                        Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
165                        Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
166                        StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
167                        setInlining(false);
168                    }
169                    else if (isLambdaConstructorCall(owner, name)) { //TODO add method
170                        assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
171                        if (invocation.shouldRegenerate()) {
172                            //put additional captured parameters on stack
173                            List<CapturedParamInfo> recaptured = invocation.getAllRecapturedParameters();
174                            List<CapturedParamInfo> contextCaptured = MethodInliner.this.parameters.getCaptured();
175                            for (CapturedParamInfo capturedParamInfo : recaptured) {
176                                CapturedParamInfo result = null;
177                                for (CapturedParamInfo info : contextCaptured) {
178                                    //TODO more sophisticated check
179                                    if (info.getFieldName().equals(capturedParamInfo.getFieldName())) {
180                                        result = info;
181                                    }
182                                }
183                                if (result == null) {
184                                    throw new UnsupportedOperationException(
185                                            "Unsupported operation: could not transform non-inline lambda inside inlined one: " +
186                                            owner + "." + name);
187                                }
188                                super.visitVarInsn(capturedParamInfo.getType().getOpcode(Opcodes.ILOAD), result.getIndex());
189                            }
190                            super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor());
191                            invocation = null;
192                        } else {
193                            super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
194                        }
195                    }
196                    else {
197                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
198                    }
199                }
200            };
201    
202            node.accept(inliner);
203    
204            return resultNode;
205        }
206    
207        public void merge() {
208    
209        }
210    
211        @NotNull
212        public MethodNode prepareNode(@NotNull MethodNode node) {
213            final int capturedParamsSize = parameters.getCaptured().size();
214            final int realParametersSize = parameters.getReal().size();
215            Type[] types = Type.getArgumentTypes(node.desc);
216            Type returnType = Type.getReturnType(node.desc);
217    
218            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
219            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
220    
221            node.instructions.resetLabels();
222            MethodNode transformedNode = new MethodNode(node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
223    
224                @Override
225                public void visitVarInsn(int opcode, int var) {
226                    int newIndex;
227                    if (var < realParametersSize) {
228                        newIndex = var;
229                    } else {
230                        newIndex = var + capturedParamsSize;
231                    }
232                    super.visitVarInsn(opcode, newIndex);
233                }
234    
235                @Override
236                public void visitIincInsn(int var, int increment) {
237                    int newIndex;
238                    if (var < realParametersSize) {
239                        newIndex = var;
240                    } else {
241                        newIndex = var + capturedParamsSize;
242                    }
243                    super.visitIincInsn(newIndex, increment);
244                }
245    
246                @Override
247                public void visitMaxs(int maxStack, int maxLocals) {
248                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
249                }
250            };
251    
252            node.accept(transformedNode);
253    
254            transformCaptured(transformedNode);
255    
256            return transformedNode;
257        }
258    
259        @NotNull
260        protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) throws AnalyzerException {
261            node = prepareNode(node);
262    
263            Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
264            Frame<SourceValue>[] sources = analyzer.analyze("fake", node);
265    
266            AbstractInsnNode cur = node.instructions.getFirst();
267            int index = 0;
268            Set<LabelNode> deadLabels = new HashSet<LabelNode>();
269    
270            while (cur != null) {
271                Frame<SourceValue> frame = sources[index];
272    
273                if (frame != null) {
274                    if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
275                        MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
276                        String owner = methodInsnNode.owner;
277                        String desc = methodInsnNode.desc;
278                        String name = methodInsnNode.name;
279                        //TODO check closure
280                        int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
281                        if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
282                            SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
283    
284                            LambdaInfo lambdaInfo = null;
285                            int varIndex = -1;
286    
287                            if (sourceValue.insns.size() == 1) {
288                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
289                                if (insnNode.getType() == AbstractInsnNode.VAR_INSN) {
290                                    assert insnNode.getOpcode() == Opcodes.ALOAD : insnNode.toString();
291                                    varIndex = ((VarInsnNode) insnNode).var;
292                                    lambdaInfo = getLambda(varIndex);
293    
294                                    if (lambdaInfo != null) {
295                                        //remove inlinable access
296                                        node.instructions.remove(insnNode);
297                                    }
298                                }
299                            }
300    
301                            invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
302                        }
303                        else if (isLambdaConstructorCall(owner, name)) {
304                            Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
305                            int paramStart = frame.getStackSize() - paramLength;
306    
307                            for (int i = 0; i < paramLength; i++) {
308                                SourceValue sourceValue = frame.getStack(paramStart + i);
309                                if (sourceValue.insns.size() == 1) {
310                                    AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
311                                    if (insnNode.getOpcode() == Opcodes.ALOAD) {
312                                        int varIndex = ((VarInsnNode) insnNode).var;
313                                        LambdaInfo lambdaInfo = getLambda(varIndex);
314                                        if (lambdaInfo != null) {
315                                            lambdaMapping.put(i, lambdaInfo);
316                                            node.instructions.remove(insnNode);
317                                        }
318                                    }
319                                }
320                            }
321    
322                            constructorInvocations.add(new ConstructorInvocation(owner, lambdaMapping, isSameModule));
323                        }
324                    }
325                }
326    
327                AbstractInsnNode prevNode = cur;
328                cur = cur.getNext();
329                index++;
330    
331                //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
332                if (frame == null) {
333                    //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
334                    if (prevNode.getType() == AbstractInsnNode.LABEL) {
335                        deadLabels.add((LabelNode) prevNode);
336                    } else {
337                        node.instructions.remove(prevNode);
338                    }
339                }
340            }
341    
342            //clean dead try/catch blocks
343            List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
344            for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
345                TryCatchBlockNode block = iterator.next();
346                if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
347                    iterator.remove();
348                }
349            }
350    
351            return node;
352        }
353    
354        @Nullable
355        public LambdaInfo getLambda(int index) {
356            if (index < parameters.totalSize()) {
357                return parameters.get(index).getLambda();
358            }
359            return null;
360        }
361    
362        private static void removeClosureAssertions(MethodNode node) {
363            AbstractInsnNode cur = node.instructions.getFirst();
364            while (cur != null && cur.getNext() != null) {
365                AbstractInsnNode next = cur.getNext();
366                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
367                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
368                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
369                        AbstractInsnNode prev = cur.getPrevious();
370    
371                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
372                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
373    
374                        node.instructions.remove(prev);
375                        node.instructions.remove(cur);
376                        cur = next.getNext();
377                        node.instructions.remove(next);
378                        next = cur;
379                    }
380                }
381                cur = next;
382            }
383        }
384    
385        private void transformCaptured(@NotNull MethodNode node) {
386            if (lambdaType == null) {
387                return;
388            }
389    
390            //remove all this and shift all variables to captured ones size
391            AbstractInsnNode cur = node.instructions.getFirst();
392            while (cur != null) {
393                if (cur.getType() == AbstractInsnNode.FIELD_INSN) {
394                    FieldInsnNode fieldInsnNode = (FieldInsnNode) cur;
395                    //TODO check closure
396                    if (lambdaFieldRemapper.canProcess(fieldInsnNode.owner, lambdaType.getInternalName())) {
397                        CapturedParamInfo result = this.lambdaFieldRemapper.findField(fieldInsnNode, parameters.getCaptured());
398    
399                        if (result == null) {
400                            throw new UnsupportedOperationException("Coudn't find field " +
401                                                                    fieldInsnNode.owner +
402                                                                    "." +
403                                                                    fieldInsnNode.name +
404                                                                    " (" +
405                                                                    fieldInsnNode.desc +
406                                                                    ") in captured vars of " + lambdaType);
407                        }
408    
409                        if (result.isSkipped()) {
410                            //lambda class transformation: skip captured this
411                        } else {
412                            cur = this.lambdaFieldRemapper.doTransform(node, fieldInsnNode, result);
413                        }
414                    }
415                }
416                cur = cur.getNext();
417            }
418        }
419    
420        public static AbstractInsnNode getPreviousNoLabelNoLine(AbstractInsnNode cur) {
421            AbstractInsnNode prev = cur.getPrevious();
422            while (prev.getType() == AbstractInsnNode.LABEL || prev.getType() == AbstractInsnNode.LINE) {
423                prev = prev.getPrevious();
424            }
425            return prev;
426        }
427    
428        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
429            Type[] actualParams = Type.getArgumentTypes(descriptor);
430            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
431    
432            int size = 0;
433            for (Type next : directOrder) {
434                size += next.getSize();
435            }
436    
437            shift += size;
438            int index = directOrder.size();
439    
440            for (Type next : Lists.reverse(directOrder)) {
441                shift -= next.getSize();
442                Type typeOnStack = actualParams[--index];
443                if (!typeOnStack.equals(next)) {
444                    StackValue.onStack(typeOnStack).put(next, iv);
445                }
446                iv.store(shift, next);
447            }
448        }
449    
450        //TODO: check annotation on class - it's package part
451        //TODO: check it's external module
452        //TODO?: assert method exists in facade?
453        public String changeOwnerForExternalPackage(String type, int opcode) {
454            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
455                return type;
456            }
457    
458            int i = type.indexOf('-');
459            if (i >= 0) {
460                return type.substring(0, i);
461            }
462            return type;
463        }
464    }