001    /*
002    * Copyright 2010-2013 JetBrains s.r.o.
003    *
004    * Licensed under the Apache License, Version 2.0 (the "License");
005    * you may not use this file except in compliance with the License.
006    * You may obtain a copy of the License at
007    *
008    * http://www.apache.org/licenses/LICENSE-2.0
009    *
010    * Unless required by applicable law or agreed to in writing, software
011    * distributed under the License is distributed on an "AS IS" BASIS,
012    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013    * See the License for the specific language governing permissions and
014    * limitations under the License.
015    */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
019    import com.intellij.openapi.vfs.VirtualFile;
020    import com.intellij.psi.PsiElement;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.annotations.Nullable;
023    import org.jetbrains.jet.codegen.*;
024    import org.jetbrains.jet.codegen.binding.CodegenBinding;
025    import org.jetbrains.jet.codegen.context.CodegenContext;
026    import org.jetbrains.jet.codegen.context.MethodContext;
027    import org.jetbrains.jet.codegen.context.PackageContext;
028    import org.jetbrains.jet.codegen.state.GenerationState;
029    import org.jetbrains.jet.codegen.state.JetTypeMapper;
030    import org.jetbrains.jet.descriptors.serialization.descriptors.DeserializedSimpleFunctionDescriptor;
031    import org.jetbrains.jet.lang.descriptors.*;
032    import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor;
033    import org.jetbrains.jet.lang.psi.*;
034    import org.jetbrains.jet.lang.resolve.BindingContext;
035    import org.jetbrains.jet.lang.resolve.DescriptorToSourceUtils;
036    import org.jetbrains.jet.lang.resolve.DescriptorUtils;
037    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall;
038    import org.jetbrains.jet.lang.resolve.java.AsmTypeConstants;
039    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodParameterKind;
040    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodParameterSignature;
041    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodSignature;
042    import org.jetbrains.jet.lang.types.lang.InlineStrategy;
043    import org.jetbrains.jet.lang.types.lang.InlineUtil;
044    import org.jetbrains.jet.renderer.DescriptorRenderer;
045    import org.jetbrains.org.objectweb.asm.MethodVisitor;
046    import org.jetbrains.org.objectweb.asm.Opcodes;
047    import org.jetbrains.org.objectweb.asm.Type;
048    import org.jetbrains.org.objectweb.asm.commons.Method;
049    import org.jetbrains.org.objectweb.asm.tree.MethodNode;
050    import org.jetbrains.org.objectweb.asm.util.Textifier;
051    import org.jetbrains.org.objectweb.asm.util.TraceMethodVisitor;
052    
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
060    
061    import static org.jetbrains.jet.codegen.AsmUtil.*;
062    
063    public class InlineCodegen implements CallGenerator {
064        private final GenerationState state;
065        private final JetTypeMapper typeMapper;
066        private final BindingContext bindingContext;
067    
068        private final SimpleFunctionDescriptor functionDescriptor;
069        private final JvmMethodSignature jvmSignature;
070        private final JetElement callElement;
071        private final MethodContext context;
072        private final ExpressionCodegen codegen;
073    
074        private final boolean asFunctionInline;
075        private final int initialFrameSize;
076        private final boolean isSameModule;
077    
078        protected final ParametersBuilder invocationParamBuilder = ParametersBuilder.newBuilder();
079        protected final Map<Integer, LambdaInfo> expressionMap = new HashMap<Integer, LambdaInfo>();
080    
081        private LambdaInfo activeLambda;
082    
083        public InlineCodegen(
084                @NotNull ExpressionCodegen codegen,
085                @NotNull GenerationState state,
086                @NotNull SimpleFunctionDescriptor functionDescriptor,
087                @NotNull JetElement callElement
088        ) {
089            assert functionDescriptor.getInlineStrategy().isInline() : "InlineCodegen could inline only inline function but " + functionDescriptor;
090    
091            this.state = state;
092            this.typeMapper = state.getTypeMapper();
093            this.codegen = codegen;
094            this.callElement = callElement;
095            this.functionDescriptor = functionDescriptor.getOriginal();
096            bindingContext = codegen.getBindingContext();
097            initialFrameSize = codegen.getFrameMap().getCurrentSize();
098    
099            context = (MethodContext) getContext(functionDescriptor, state);
100            jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
101    
102            InlineStrategy inlineStrategy =
103                    codegen.getContext().isInlineFunction() ? InlineStrategy.IN_PLACE : functionDescriptor.getInlineStrategy();
104            this.asFunctionInline = false;
105    
106            isSameModule = JvmCodegenUtil.isCallInsideSameModuleAsDeclared(functionDescriptor, codegen.getContext(), state.getOutDirectory());
107        }
108    
109        @Override
110        public void genCallWithoutAssertions(
111                @NotNull CallableMethod callableMethod, @NotNull ExpressionCodegen codegen
112        ) {
113            genCall(callableMethod, null, false, codegen);
114        }
115    
116        @Override
117        public void genCall(@NotNull CallableMethod callableMethod, @Nullable ResolvedCall<?> resolvedCall, boolean callDefault, @NotNull ExpressionCodegen codegen) {
118            MethodNode node = null;
119    
120            try {
121                node = createMethodNode(callDefault);
122                endCall(inlineCall(node));
123            }
124            catch (CompilationException e) {
125                throw e;
126            }
127            catch (Exception e) {
128                boolean generateNodeText = !(e instanceof InlineException);
129                PsiElement element = DescriptorToSourceUtils.descriptorToDeclaration(this.codegen.getContext().getContextDescriptor());
130                throw new CompilationException("Couldn't inline method call '" +
131                                           functionDescriptor.getName() +
132                                           "' into \n" + (element != null ? element.getText() : "null psi element " + this.codegen.getContext().getContextDescriptor()) +
133                                           (generateNodeText ? ("\ncause: " + getNodeText(node)) : ""),
134                                           e, callElement);
135            }
136    
137    
138        }
139    
140        private void endCall(@NotNull InlineResult result) {
141            leaveTemps();
142    
143            state.getFactory().removeInlinedClasses(result.getClassesToRemove());
144        }
145    
146        @NotNull
147        private MethodNode createMethodNode(boolean callDefault) throws ClassNotFoundException, IOException {
148            JvmMethodSignature jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
149    
150            Method asmMethod;
151            if (callDefault) {
152                asmMethod = typeMapper.mapDefaultMethod(functionDescriptor, context.getContextKind(), context);
153            }
154            else {
155                asmMethod = jvmSignature.getAsmMethod();
156            }
157    
158            MethodNode node;
159            if (functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) {
160                VirtualFile file = InlineCodegenUtil.getVirtualFileForCallable((DeserializedSimpleFunctionDescriptor) functionDescriptor, state);
161                node = InlineCodegenUtil.getMethodNode(file.getInputStream(), asmMethod.getName(), asmMethod.getDescriptor());
162    
163                if (node == null) {
164                    throw new RuntimeException("Couldn't obtain compiled function body for " + descriptorName(functionDescriptor));
165                }
166            }
167            else {
168                PsiElement element = DescriptorToSourceUtils.descriptorToDeclaration(functionDescriptor);
169    
170                if (element == null) {
171                    throw new RuntimeException("Couldn't find declaration for function " + descriptorName(functionDescriptor));
172                }
173    
174                node = new MethodNode(InlineCodegenUtil.API,
175                                               getMethodAsmFlags(functionDescriptor, context.getContextKind()) | (callDefault ? Opcodes.ACC_STATIC : 0),
176                                               asmMethod.getName(),
177                                               asmMethod.getDescriptor(),
178                                               jvmSignature.getGenericsSignature(),
179                                               null);
180    
181                //for maxLocals calculation
182                MethodVisitor maxCalcAdapter = InlineCodegenUtil.wrapWithMaxLocalCalc(node);
183                MethodContext methodContext = context.getParentContext().intoFunction(functionDescriptor);
184                MemberCodegen<?> parentCodegen = codegen.getParentCodegen();
185                if (callDefault) {
186                    boolean isStatic = isStatic(codegen.getContext().getContextKind());
187                    FunctionCodegen.generateDefaultImplBody(
188                            methodContext, jvmSignature, functionDescriptor, isStatic, maxCalcAdapter, DefaultParameterValueLoader.DEFAULT,
189                            (JetNamedFunction) element, parentCodegen, state
190                    );
191                }
192                else {
193                    FunctionCodegen.generateMethodBody(
194                            maxCalcAdapter, functionDescriptor, methodContext, jvmSignature,
195                            new FunctionGenerationStrategy.FunctionDefault(state, functionDescriptor, (JetDeclarationWithBody) element),
196                            parentCodegen
197                    );
198                }
199                maxCalcAdapter.visitMaxs(-1, -1);
200                maxCalcAdapter.visitEnd();
201            }
202            return node;
203        }
204    
205        private InlineResult inlineCall(MethodNode node) {
206            generateClosuresBodies();
207    
208            //through generation captured parameters will be added to invocationParamBuilder
209            putClosureParametersOnStack();
210    
211            Parameters parameters = invocationParamBuilder.buildParameters();
212    
213            InliningContext info = new RootInliningContext(expressionMap,
214                                                           state,
215                                                           codegen.getInlineNameGenerator()
216                                                                   .subGenerator(functionDescriptor.getName().asString()),
217                                                           codegen.getContext(),
218                                                           callElement,
219                                                           codegen.getParentCodegen().getClassName());
220    
221            MethodInliner inliner = new MethodInliner(node, parameters, info, new FieldRemapper(null, null, parameters), isSameModule, "Method inlining " + callElement.getText()); //with captured
222    
223            LocalVarRemapper remapper = new LocalVarRemapper(parameters, initialFrameSize);
224    
225    
226            MethodNode adapter = InlineCodegenUtil.createEmptyMethodNode();
227            InlineResult result = inliner.doInline(adapter, remapper, true, LabelOwner.SKIP_ALL);
228    
229            LabelOwner labelOwner = new LabelOwner() {
230    
231                final CallableMemberDescriptor descriptor = codegen.getContext().getContextDescriptor();
232    
233                final boolean isLambda = CodegenBinding.isLocalFunOrLambda(descriptor) && descriptor.getName().isSpecial();
234    
235                @Override
236                public boolean isMyLabel(@NotNull String name) {
237                    if (InlineCodegenUtil.ROOT_LABEL.equals(name)) {
238                        return !isLambda;
239                    }
240                    else {
241                        return descriptor.getName().asString().equals(name);
242                    }
243                }
244            };
245            List<MethodInliner.FinallyBlockInfo> infos = MethodInliner.processReturns(adapter, labelOwner, true, null);
246            generateAndInsertFinallyBlocks(adapter, infos);
247    
248            adapter.accept(new InliningInstructionAdapter(codegen.v));
249            return result;
250        }
251    
252        private void generateClosuresBodies() {
253            for (LambdaInfo info : expressionMap.values()) {
254                info.setNode(generateLambdaBody(info));
255            }
256        }
257    
258        private MethodNode generateLambdaBody(LambdaInfo info) {
259            JetFunctionLiteral declaration = info.getFunctionLiteral();
260            FunctionDescriptor descriptor = info.getFunctionDescriptor();
261    
262            MethodContext parentContext = codegen.getContext();
263    
264            MethodContext context = parentContext.intoClosure(descriptor, codegen, typeMapper).intoInlinedLambda(descriptor);
265    
266            JvmMethodSignature jvmMethodSignature = typeMapper.mapSignature(descriptor);
267            Method asmMethod = jvmMethodSignature.getAsmMethod();
268            MethodNode methodNode = new MethodNode(InlineCodegenUtil.API, getMethodAsmFlags(descriptor, context.getContextKind()), asmMethod.getName(), asmMethod.getDescriptor(), jvmMethodSignature.getGenericsSignature(), null);
269    
270            MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(methodNode);
271    
272            FunctionCodegen.generateMethodBody(adapter, descriptor, context, jvmMethodSignature, new FunctionGenerationStrategy.FunctionDefault(state, descriptor, declaration), codegen.getParentCodegen());
273            adapter.visitMaxs(-1, -1);
274    
275            return methodNode;
276        }
277    
278    
279    
280        @Override
281        public void afterParameterPut(@NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor) {
282            putCapturedInLocal(type, stackValue, valueParameterDescriptor, -1);
283        }
284    
285        private void putCapturedInLocal(
286                @NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor, int capturedParamIndex
287        ) {
288            if (!asFunctionInline && Type.VOID_TYPE != type) {
289                //TODO remap only inlinable closure => otherwise we could get a lot of problem
290                boolean couldBeRemapped = !shouldPutValue(type, stackValue, valueParameterDescriptor);
291                StackValue remappedIndex = couldBeRemapped ? stackValue : null;
292    
293                ParameterInfo info;
294                if (capturedParamIndex >= 0) {
295                    CapturedParamDesc capturedParamInfoInLambda = activeLambda.getCapturedVars().get(capturedParamIndex);
296                    info = invocationParamBuilder.addCapturedParam(capturedParamInfoInLambda, capturedParamInfoInLambda.getFieldName());
297                    info.setRemapValue(remappedIndex);
298                }
299                else {
300                    info = invocationParamBuilder.addNextParameter(type, false, remappedIndex);
301                }
302    
303                putParameterOnStack(info);
304            }
305        }
306    
307        /*descriptor is null for captured vars*/
308        public boolean shouldPutValue(
309                @NotNull Type type,
310                @Nullable StackValue stackValue,
311                @Nullable ValueParameterDescriptor descriptor
312        ) {
313    
314            if (stackValue == null) {
315                //default or vararg
316                return true;
317            }
318    
319            //remap only inline functions (and maybe non primitives)
320            //TODO - clean asserion and remapping logic
321            if (isPrimitive(type) != isPrimitive(stackValue.type)) {
322                //don't remap boxing/unboxing primitives - lost identity and perfomance
323                return true;
324            }
325    
326            if (stackValue instanceof StackValue.Local) {
327                return false;
328            }
329    
330            if (stackValue instanceof StackValue.Composed) {
331                //see: Method.isSpecialStackValue: go through aload 0
332                if (codegen.getContext().isInliningLambda() && codegen.getContext().getContextDescriptor() instanceof AnonymousFunctionDescriptor) {
333                    if (descriptor != null && !InlineUtil.hasNoinlineAnnotation(descriptor)) {
334                        //TODO: check type of context
335                        return false;
336                    }
337                }
338            }
339            return true;
340        }
341    
342        private void putParameterOnStack(ParameterInfo... infos) {
343            int[] index = new int[infos.length];
344            for (int i = 0; i < infos.length; i++) {
345                ParameterInfo info = infos[i];
346                if (!info.isSkippedOrRemapped()) {
347                    index[i] = codegen.getFrameMap().enterTemp(info.getType());
348                }
349                else {
350                    index[i] = -1;
351                }
352            }
353    
354            for (int i = infos.length - 1; i >= 0; i--) {
355                ParameterInfo info = infos[i];
356                if (!info.isSkippedOrRemapped()) {
357                    Type type = info.type;
358                    StackValue.local(index[i], type).store(type, codegen.v);
359                }
360            }
361        }
362    
363        @Override
364        public void putHiddenParams() {
365            List<JvmMethodParameterSignature> valueParameters = jvmSignature.getValueParameters();
366    
367            if (!isStaticMethod(functionDescriptor, context)) {
368                invocationParamBuilder.addNextParameter(AsmTypeConstants.OBJECT_TYPE, false, null);
369            }
370    
371            for (JvmMethodParameterSignature param : valueParameters) {
372                if (param.getKind() == JvmMethodParameterKind.VALUE) {
373                    break;
374                }
375                invocationParamBuilder.addNextParameter(param.getAsmType(), false, null);
376            }
377    
378            List<ParameterInfo> infos = invocationParamBuilder.listNotCaptured();
379            putParameterOnStack(infos.toArray(new ParameterInfo[infos.size()]));
380        }
381    
382        public void leaveTemps() {
383            FrameMap frameMap = codegen.getFrameMap();
384            List<ParameterInfo> infos = invocationParamBuilder.listAllParams();
385            for (ListIterator<? extends ParameterInfo> iterator = infos.listIterator(infos.size()); iterator.hasPrevious(); ) {
386                ParameterInfo param = iterator.previous();
387                if (!param.isSkippedOrRemapped()) {
388                    frameMap.leaveTemp(param.type);
389                }
390            }
391        }
392    
393        public static boolean isInliningClosure(JetExpression expression, ValueParameterDescriptor valueParameterDescriptora) {
394            //TODO deparenthisise typed
395            JetExpression deparenthesize = JetPsiUtil.deparenthesize(expression);
396            return deparenthesize instanceof JetFunctionLiteralExpression &&
397                   !InlineUtil.hasNoinlineAnnotation(valueParameterDescriptora);
398        }
399    
400        public void rememberClosure(JetExpression expression, Type type) {
401            JetFunctionLiteralExpression lambda = (JetFunctionLiteralExpression) JetPsiUtil.deparenthesize(expression);
402            assert lambda != null : "Couldn't find lambda in " + expression.getText();
403    
404            String labelNameIfPresent = null;
405            PsiElement parent = lambda.getParent();
406            if (parent instanceof JetLabeledExpression) {
407                labelNameIfPresent = ((JetLabeledExpression) parent).getLabelName();
408            }
409            LambdaInfo info = new LambdaInfo(lambda, typeMapper, labelNameIfPresent);
410    
411            ParameterInfo closureInfo = invocationParamBuilder.addNextParameter(type, true, null);
412            closureInfo.setLambda(info);
413            expressionMap.put(closureInfo.getIndex(), info);
414        }
415    
416        private void putClosureParametersOnStack() {
417            for (LambdaInfo next : expressionMap.values()) {
418                activeLambda = next;
419                codegen.pushClosureOnStack(next.closure, false, this);
420            }
421            activeLambda = null;
422        }
423    
424        public static CodegenContext getContext(DeclarationDescriptor descriptor, GenerationState state) {
425            if (descriptor instanceof PackageFragmentDescriptor) {
426                return new PackageContext((PackageFragmentDescriptor) descriptor, null, null);
427            }
428    
429            CodegenContext parent = getContext(descriptor.getContainingDeclaration(), state);
430    
431            if (descriptor instanceof ClassDescriptor) {
432                OwnerKind kind = DescriptorUtils.isTrait(descriptor) ? OwnerKind.TRAIT_IMPL : OwnerKind.IMPLEMENTATION;
433                return parent.intoClass((ClassDescriptor) descriptor, kind, state);
434            }
435            else if (descriptor instanceof FunctionDescriptor) {
436                return parent.intoFunction((FunctionDescriptor) descriptor);
437            }
438    
439            throw new IllegalStateException("Couldn't build context for " + descriptorName(descriptor));
440        }
441    
442        private static boolean isStaticMethod(FunctionDescriptor functionDescriptor, MethodContext context) {
443            return (getMethodAsmFlags(functionDescriptor, context.getContextKind()) & Opcodes.ACC_STATIC) != 0;
444        }
445    
446        @NotNull
447        public static String getNodeText(@Nullable MethodNode node) {
448            if (node == null) {
449                return "Not generated";
450            }
451            Textifier p = new Textifier();
452            node.accept(new TraceMethodVisitor(p));
453            StringWriter sw = new StringWriter();
454            p.print(new PrintWriter(sw));
455            sw.flush();
456            return node.name + " " + node.desc + ": \n " + sw.getBuffer().toString();
457        }
458    
459        private static String descriptorName(DeclarationDescriptor descriptor) {
460            return DescriptorRenderer.SHORT_NAMES_IN_TYPES.render(descriptor);
461        }
462    
463        @Override
464        public void genValueAndPut(
465                @NotNull ValueParameterDescriptor valueParameterDescriptor,
466                @NotNull JetExpression argumentExpression,
467                @NotNull Type parameterType
468        ) {
469            //TODO deparenthisise
470            if (isInliningClosure(argumentExpression, valueParameterDescriptor)) {
471                rememberClosure(argumentExpression, parameterType);
472            } else {
473                StackValue value = codegen.gen(argumentExpression);
474                putValueIfNeeded(valueParameterDescriptor, parameterType, value);
475            }
476        }
477    
478        @Override
479        public void putValueIfNeeded(@Nullable ValueParameterDescriptor valueParameterDescriptor, @NotNull Type parameterType, @NotNull StackValue value) {
480            if (shouldPutValue(parameterType, value, valueParameterDescriptor)) {
481                value.put(parameterType, codegen.v);
482            }
483            afterParameterPut(parameterType, value, valueParameterDescriptor);
484        }
485    
486        @Override
487        public void putCapturedValueOnStack(
488                @NotNull StackValue stackValue, @NotNull Type valueType, int paramIndex
489        ) {
490            if (shouldPutValue(stackValue.type, stackValue, null)) {
491                stackValue.put(stackValue.type, codegen.v);
492            }
493            putCapturedInLocal(stackValue.type, stackValue, null, paramIndex);
494        }
495    
496    
497        public void generateAndInsertFinallyBlocks(MethodNode intoNode, List<MethodInliner.FinallyBlockInfo> insertPoints) {
498            if (!codegen.hasFinallyBlocks()) return;
499    
500            for (MethodInliner.FinallyBlockInfo insertPoint : insertPoints) {
501                MethodNode finallyNode = InlineCodegenUtil.createEmptyMethodNode();
502                ExpressionCodegen finallyCodegen =
503                        new ExpressionCodegen(finallyNode, codegen.getFrameMap(), codegen.getReturnType(),
504                                              codegen.getContext(), codegen.getState(), codegen.getParentCodegen());
505                finallyCodegen.addBlockStackElementsForNonLocalReturns(codegen.getBlockStackElements());
506    
507                finallyCodegen.generateFinallyBlocksIfNeeded(insertPoint.returnType);
508    
509                InlineCodegenUtil.insertNodeBefore(finallyNode, intoNode, insertPoint.beforeIns);
510            }
511        }
512    
513    }