001    package org.jetbrains.jet.codegen.inline;
002    
003    import com.google.common.collect.Lists;
004    import com.intellij.util.ArrayUtil;
005    import org.jetbrains.annotations.NotNull;
006    import org.jetbrains.annotations.Nullable;
007    import org.jetbrains.asm4.Label;
008    import org.jetbrains.asm4.MethodVisitor;
009    import org.jetbrains.asm4.Opcodes;
010    import org.jetbrains.asm4.Type;
011    import org.jetbrains.asm4.commons.InstructionAdapter;
012    import org.jetbrains.asm4.commons.Method;
013    import org.jetbrains.asm4.commons.RemappingMethodAdapter;
014    import org.jetbrains.asm4.tree.*;
015    import org.jetbrains.asm4.tree.analysis.*;
016    import org.jetbrains.jet.codegen.AsmUtil;
017    import org.jetbrains.jet.codegen.ClosureCodegen;
018    import org.jetbrains.jet.codegen.StackValue;
019    import org.jetbrains.jet.codegen.state.JetTypeMapper;
020    import org.jetbrains.jet.utils.UtilsPackage;
021    
022    import java.util.*;
023    
024    import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.*;
025    
026    public class MethodInliner {
027    
028        private final MethodNode node;
029    
030        private final Parameters parameters;
031    
032        private final InliningContext inliningContext;
033    
034        @Nullable
035        private final Type lambdaType;
036    
037        private final LambdaFieldRemapper lambdaFieldRemapper;
038    
039        private final boolean isSameModule;
040    
041        private final JetTypeMapper typeMapper;
042    
043        private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
044    
045        //keeps order
046        private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
047        //current state
048        private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
049    
050        private final InlineResult result;
051    
052        /*
053         *
054         * @param node
055         * @param parameters
056         * @param inliningContext
057         * @param lambdaType - in case on lambda 'invoke' inlining
058         */
059        public MethodInliner(
060                @NotNull MethodNode node,
061                @NotNull Parameters parameters,
062                @NotNull InliningContext parent,
063                @Nullable Type lambdaType,
064                LambdaFieldRemapper lambdaFieldRemapper,
065                boolean isSameModule
066        ) {
067            this.node = node;
068            this.parameters = parameters;
069            this.inliningContext = parent;
070            this.lambdaType = lambdaType;
071            this.lambdaFieldRemapper = lambdaFieldRemapper;
072            this.isSameModule = isSameModule;
073            this.typeMapper = parent.state.getTypeMapper();
074            this.result = InlineResult.create();
075        }
076    
077    
078        public InlineResult doInline(MethodVisitor adapter, VarRemapper.ParamRemapper remapper) {
079            return doInline(adapter, remapper, lambdaFieldRemapper, true);
080        }
081    
082        public InlineResult doInline(
083                MethodVisitor adapter,
084                VarRemapper.ParamRemapper remapper,
085                LambdaFieldRemapper capturedRemapper, boolean remapReturn
086        ) {
087            //analyze body
088            MethodNode transformedNode = node;
089            try {
090                transformedNode = markPlacesForInlineAndRemoveInlinable(transformedNode);
091            }
092            catch (AnalyzerException e) {
093                throw UtilsPackage.rethrow(e);
094            }
095    
096            transformedNode = doInline(transformedNode, capturedRemapper);
097            removeClosureAssertions(transformedNode);
098            transformedNode.instructions.resetLabels();
099    
100            Label end = new Label();
101            RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn);
102            transformedNode.accept(visitor);
103            visitor.visitLabel(end);
104    
105            return result;
106        }
107    
108        private MethodNode doInline(MethodNode node, final LambdaFieldRemapper capturedRemapper) {
109    
110            final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
111    
112            MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
113    
114            final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
115    
116            RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
117                                                                                       new TypeRemapper(currentTypeMapping));
118    
119            InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
120    
121                private ConstructorInvocation invocation;
122                @Override
123                public void anew(Type type) {
124                    if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
125                        invocation = iterator.next();
126    
127                        if (invocation.shouldRegenerate()) {
128                            //TODO: need poping of type but what to do with local funs???
129                            Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
130                            currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
131                            LambdaTransformer transformer = new LambdaTransformer(invocation.getOwnerInternalName(),
132                                                                                  inliningContext.subInline(inliningContext.nameGenerator, currentTypeMapping).classRegeneration(),
133                                                                                  isSameModule, newLambdaType);
134    
135                            InlineResult transformResult = transformer.doTransform(invocation, capturedRemapper);
136                            result.addAllClassesToRemove(transformResult);
137    
138                            if (inliningContext.isInliningLambda) {
139                                //this class is transformed and original not used so we should remove original one after inlining
140                                result.addClassToRemove(invocation.getOwnerInternalName());
141                            }
142                        }
143                    }
144    
145                    //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
146                    super.anew(type);
147                }
148    
149                @Override
150                public void visitMethodInsn(int opcode, String owner, String name, String desc) {
151                    if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
152                        assert !currentInvokes.isEmpty();
153                        InvokeCall invokeCall = currentInvokes.remove();
154                        LambdaInfo info = invokeCall.lambdaInfo;
155    
156                        if (info == null) {
157                            //noninlinable lambda
158                            super.visitMethodInsn(opcode, owner, name, desc);
159                            return;
160                        }
161    
162                        int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
163                        putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
164    
165                        Parameters lambdaParameters = info.addAllParameters(capturedRemapper);
166    
167                        setInlining(true);
168                        MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
169                                                                  inliningContext.subInlineLambda(info),
170                                                                  info.getLambdaClassType(),
171                                                                  capturedRemapper, true /*cause all calls in same module as lambda*/
172                        );
173    
174                        VarRemapper.ParamRemapper remapper = new VarRemapper.ParamRemapper(lambdaParameters, valueParamShift);
175                        InlineResult lambdaResult = inliner.doInline(this.mv, remapper);//TODO add skipped this and receiver
176                        result.addAllClassesToRemove(lambdaResult);
177    
178                        //return value boxing/unboxing
179                        Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
180                        Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
181                        StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
182                        setInlining(false);
183                    }
184                    else if (isLambdaConstructorCall(owner, name)) { //TODO add method
185                        assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
186                        if (invocation.shouldRegenerate()) {
187                            //put additional captured parameters on stack
188                            List<CapturedParamInfo> recaptured = invocation.getAllRecapturedParameters();
189                            List<CapturedParamInfo> contextCaptured = MethodInliner.this.parameters.getCaptured();
190                            for (CapturedParamInfo capturedParamInfo : recaptured) {
191                                CapturedParamInfo result = null;
192                                for (CapturedParamInfo info : contextCaptured) {
193                                    //TODO more sophisticated check
194                                    if (info.getFieldName().equals(capturedParamInfo.getFieldName())) {
195                                        result = info;
196                                    }
197                                }
198                                if (result == null) {
199                                    throw new UnsupportedOperationException(
200                                            "Unsupported operation: could not transform non-inline lambda inside inlined one: " +
201                                            owner + "." + name);
202                                }
203                                super.visitVarInsn(capturedParamInfo.getType().getOpcode(Opcodes.ILOAD), result.getIndex());
204                            }
205                            super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor());
206                            invocation = null;
207                        } else {
208                            super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
209                        }
210                    }
211                    else {
212                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc);
213                    }
214                }
215            };
216    
217            node.accept(inliner);
218    
219            return resultNode;
220        }
221    
222        public void merge() {
223    
224        }
225    
226        @NotNull
227        public MethodNode prepareNode(@NotNull MethodNode node) {
228            final int capturedParamsSize = parameters.getCaptured().size();
229            final int realParametersSize = parameters.getReal().size();
230            Type[] types = Type.getArgumentTypes(node.desc);
231            Type returnType = Type.getReturnType(node.desc);
232    
233            ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
234            Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
235    
236            node.instructions.resetLabels();
237            MethodNode transformedNode = new MethodNode(node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
238    
239                @Override
240                public void visitVarInsn(int opcode, int var) {
241                    int newIndex;
242                    if (var < realParametersSize) {
243                        newIndex = var;
244                    } else {
245                        newIndex = var + capturedParamsSize;
246                    }
247                    super.visitVarInsn(opcode, newIndex);
248                }
249    
250                @Override
251                public void visitIincInsn(int var, int increment) {
252                    int newIndex;
253                    if (var < realParametersSize) {
254                        newIndex = var;
255                    } else {
256                        newIndex = var + capturedParamsSize;
257                    }
258                    super.visitIincInsn(newIndex, increment);
259                }
260    
261                @Override
262                public void visitMaxs(int maxStack, int maxLocals) {
263                    super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
264                }
265            };
266    
267            node.accept(transformedNode);
268    
269            transformCaptured(transformedNode);
270    
271            return transformedNode;
272        }
273    
274        @NotNull
275        protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) throws AnalyzerException {
276            node = prepareNode(node);
277    
278            Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
279            Frame<SourceValue>[] sources = analyzer.analyze("fake", node);
280    
281            AbstractInsnNode cur = node.instructions.getFirst();
282            int index = 0;
283            Set<LabelNode> deadLabels = new HashSet<LabelNode>();
284    
285            while (cur != null) {
286                Frame<SourceValue> frame = sources[index];
287    
288                if (frame != null) {
289                    if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
290                        MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
291                        String owner = methodInsnNode.owner;
292                        String desc = methodInsnNode.desc;
293                        String name = methodInsnNode.name;
294                        //TODO check closure
295                        int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
296                        if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
297                            SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
298    
299                            LambdaInfo lambdaInfo = null;
300                            int varIndex = -1;
301    
302                            if (sourceValue.insns.size() == 1) {
303                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
304                                if (insnNode.getType() == AbstractInsnNode.VAR_INSN) {
305                                    assert insnNode.getOpcode() == Opcodes.ALOAD : insnNode.toString();
306                                    varIndex = ((VarInsnNode) insnNode).var;
307                                    lambdaInfo = getLambda(varIndex);
308    
309                                    if (lambdaInfo != null) {
310                                        //remove inlinable access
311                                        node.instructions.remove(insnNode);
312                                    }
313                                }
314                            }
315    
316                            invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
317                        }
318                        else if (isLambdaConstructorCall(owner, name)) {
319                            Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
320                            int paramStart = frame.getStackSize() - paramLength;
321    
322                            for (int i = 0; i < paramLength; i++) {
323                                SourceValue sourceValue = frame.getStack(paramStart + i);
324                                if (sourceValue.insns.size() == 1) {
325                                    AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
326                                    if (insnNode.getOpcode() == Opcodes.ALOAD) {
327                                        int varIndex = ((VarInsnNode) insnNode).var;
328                                        LambdaInfo lambdaInfo = getLambda(varIndex);
329                                        if (lambdaInfo != null) {
330                                            lambdaMapping.put(i, lambdaInfo);
331                                            node.instructions.remove(insnNode);
332                                        }
333                                    }
334                                }
335                            }
336    
337                            constructorInvocations.add(new ConstructorInvocation(owner, lambdaMapping, isSameModule, inliningContext.classRegeneration));
338                        }
339                    }
340                }
341    
342                AbstractInsnNode prevNode = cur;
343                cur = cur.getNext();
344                index++;
345    
346                //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
347                if (frame == null) {
348                    //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
349                    if (prevNode.getType() == AbstractInsnNode.LABEL) {
350                        deadLabels.add((LabelNode) prevNode);
351                    } else {
352                        node.instructions.remove(prevNode);
353                    }
354                }
355            }
356    
357            //clean dead try/catch blocks
358            List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
359            for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
360                TryCatchBlockNode block = iterator.next();
361                if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
362                    iterator.remove();
363                }
364            }
365    
366            return node;
367        }
368    
369        @Nullable
370        public LambdaInfo getLambda(int index) {
371            if (index < parameters.totalSize()) {
372                return parameters.get(index).getLambda();
373            }
374            return null;
375        }
376    
377        private static void removeClosureAssertions(MethodNode node) {
378            AbstractInsnNode cur = node.instructions.getFirst();
379            while (cur != null && cur.getNext() != null) {
380                AbstractInsnNode next = cur.getNext();
381                if (next.getType() == AbstractInsnNode.METHOD_INSN) {
382                    MethodInsnNode methodInsnNode = (MethodInsnNode) next;
383                    if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
384                        AbstractInsnNode prev = cur.getPrevious();
385    
386                        assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
387                        assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
388    
389                        node.instructions.remove(prev);
390                        node.instructions.remove(cur);
391                        cur = next.getNext();
392                        node.instructions.remove(next);
393                        next = cur;
394                    }
395                }
396                cur = next;
397            }
398        }
399    
400        private void transformCaptured(@NotNull MethodNode node) {
401            if (lambdaType == null) {
402                return;
403            }
404    
405            //remove all this and shift all variables to captured ones size
406            AbstractInsnNode cur = node.instructions.getFirst();
407            while (cur != null) {
408                if (cur.getType() == AbstractInsnNode.FIELD_INSN) {
409                    FieldInsnNode fieldInsnNode = (FieldInsnNode) cur;
410                    //TODO check closure
411                    if (lambdaFieldRemapper.canProcess(fieldInsnNode.owner, lambdaType.getInternalName())) {
412                        CapturedParamInfo result = this.lambdaFieldRemapper.findField(fieldInsnNode, parameters.getCaptured());
413    
414                        if (result == null) {
415                            throw new UnsupportedOperationException("Coudn't find field " +
416                                                                    fieldInsnNode.owner +
417                                                                    "." +
418                                                                    fieldInsnNode.name +
419                                                                    " (" +
420                                                                    fieldInsnNode.desc +
421                                                                    ") in captured vars of " + lambdaType);
422                        }
423    
424                        if (result.isSkipped()) {
425                            //lambda class transformation: skip captured this
426                        } else {
427                            cur = this.lambdaFieldRemapper.doTransform(node, fieldInsnNode, result);
428                        }
429                    }
430                    else if (lambdaFieldRemapper.shouldPatch(fieldInsnNode)) {
431                        cur = lambdaFieldRemapper.patch(fieldInsnNode, node);
432                    }
433                }
434                cur = cur.getNext();
435            }
436        }
437    
438        public static AbstractInsnNode getPreviousNoLabelNoLine(AbstractInsnNode cur) {
439            AbstractInsnNode prev = cur.getPrevious();
440            while (prev.getType() == AbstractInsnNode.LABEL || prev.getType() == AbstractInsnNode.LINE) {
441                prev = prev.getPrevious();
442            }
443            return prev;
444        }
445    
446        public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
447            Type[] actualParams = Type.getArgumentTypes(descriptor);
448            assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
449    
450            int size = 0;
451            for (Type next : directOrder) {
452                size += next.getSize();
453            }
454    
455            shift += size;
456            int index = directOrder.size();
457    
458            for (Type next : Lists.reverse(directOrder)) {
459                shift -= next.getSize();
460                Type typeOnStack = actualParams[--index];
461                if (!typeOnStack.equals(next)) {
462                    StackValue.onStack(typeOnStack).put(next, iv);
463                }
464                iv.store(shift, next);
465            }
466        }
467    
468        //TODO: check annotation on class - it's package part
469        //TODO: check it's external module
470        //TODO?: assert method exists in facade?
471        public String changeOwnerForExternalPackage(String type, int opcode) {
472            if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
473                return type;
474            }
475    
476            int i = type.indexOf('-');
477            if (i >= 0) {
478                return type.substring(0, i);
479            }
480            return type;
481        }
482    }