001    /*
002    * Copyright 2010-2013 JetBrains s.r.o.
003    *
004    * Licensed under the Apache License, Version 2.0 (the "License");
005    * you may not use this file except in compliance with the License.
006    * You may obtain a copy of the License at
007    *
008    * http://www.apache.org/licenses/LICENSE-2.0
009    *
010    * Unless required by applicable law or agreed to in writing, software
011    * distributed under the License is distributed on an "AS IS" BASIS,
012    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013    * See the License for the specific language governing permissions and
014    * limitations under the License.
015    */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.jet.codegen.*;
import org.jetbrains.jet.codegen.context.CodegenContext;
import org.jetbrains.jet.codegen.context.MethodContext;
import org.jetbrains.jet.codegen.context.PackageContext;
import org.jetbrains.jet.codegen.signature.JvmMethodParameterKind;
import org.jetbrains.jet.codegen.signature.JvmMethodParameterSignature;
import org.jetbrains.jet.codegen.signature.JvmMethodSignature;
import org.jetbrains.jet.codegen.state.GenerationState;
import org.jetbrains.jet.codegen.state.JetTypeMapper;
import org.jetbrains.jet.descriptors.serialization.descriptors.DeserializedSimpleFunctionDescriptor;
import org.jetbrains.jet.lang.descriptors.*;
import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor;
import org.jetbrains.jet.lang.psi.*;
import org.jetbrains.jet.lang.resolve.BindingContext;
import org.jetbrains.jet.lang.resolve.BindingContextUtils;
import org.jetbrains.jet.lang.resolve.DescriptorUtils;
import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall;
import org.jetbrains.jet.lang.resolve.java.AsmTypeConstants;
import org.jetbrains.jet.lang.types.lang.InlineStrategy;
import org.jetbrains.jet.lang.types.lang.InlineUtil;
import org.jetbrains.jet.renderer.DescriptorRenderer;
import org.jetbrains.org.objectweb.asm.MethodVisitor;
import org.jetbrains.org.objectweb.asm.Opcodes;
import org.jetbrains.org.objectweb.asm.Type;
import org.jetbrains.org.objectweb.asm.commons.Method;
import org.jetbrains.org.objectweb.asm.tree.MethodNode;
import org.jetbrains.org.objectweb.asm.util.Textifier;
import org.jetbrains.org.objectweb.asm.util.TraceMethodVisitor;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;

import static org.jetbrains.jet.codegen.AsmUtil.getMethodAsmFlags;
import static org.jetbrains.jet.codegen.AsmUtil.isPrimitive;
059    
060    public class InlineCodegen implements CallGenerator {
061        private final GenerationState state;
062        private final JetTypeMapper typeMapper;
063        private final BindingContext bindingContext;
064    
065        private final SimpleFunctionDescriptor functionDescriptor;
066        private final JvmMethodSignature jvmSignature;
067        private final Call call;
068        private final MethodContext context;
069        private final ExpressionCodegen codegen;
070        private final FrameMap originalFunctionFrame;
071        private final boolean asFunctionInline;
072        private final int initialFrameSize;
073        private final boolean isSameModule;
074    
075        protected final List<ParameterInfo> actualParameters = new ArrayList<ParameterInfo>();
076        protected final Map<Integer, LambdaInfo> expressionMap = new HashMap<Integer, LambdaInfo>();
077    
078        private LambdaInfo activeLambda;
079    
080        public InlineCodegen(
081                @NotNull ExpressionCodegen codegen,
082                @NotNull GenerationState state,
083                @NotNull SimpleFunctionDescriptor functionDescriptor,
084                @NotNull Call call
085        ) {
086            assert functionDescriptor.getInlineStrategy().isInline() : "InlineCodegen could inline only inline function but " + functionDescriptor;
087    
088            this.state = state;
089            this.typeMapper = state.getTypeMapper();
090            this.codegen = codegen;
091            this.call = call;
092            this.functionDescriptor = functionDescriptor.getOriginal();
093            bindingContext = codegen.getBindingContext();
094            initialFrameSize = codegen.getFrameMap().getCurrentSize();
095    
096            context = (MethodContext) getContext(functionDescriptor, state);
097            originalFunctionFrame = context.prepareFrame(typeMapper);
098            jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
099    
100            InlineStrategy inlineStrategy =
101                    codegen.getContext().isInlineFunction() ? InlineStrategy.IN_PLACE : functionDescriptor.getInlineStrategy();
102            this.asFunctionInline = false;
103    
104            isSameModule = !(functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) /*not compiled library*/ &&
105                           JvmCodegenUtil.isCallInsideSameModuleAsDeclared(functionDescriptor, codegen.getContext());
106        }
107    
108    
109        @Override
110        public void genCall(CallableMethod callableMethod, ResolvedCall<?> resolvedCall, int mask, ExpressionCodegen codegen) {
111            assert mask == 0 : "Default method invocation couldn't be inlined " + resolvedCall;
112    
113            MethodNode node = null;
114    
115            try {
116                node = createMethodNode(callableMethod);
117                endCall(inlineCall(node));
118            }
119            catch (CompilationException e) {
120                throw e;
121            }
122            catch (Exception e) {
123                boolean generateNodeText = !(e instanceof InlineException);
124                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, this.codegen.getContext().getContextDescriptor());
125                throw new CompilationException("Couldn't inline method call '" +
126                                           functionDescriptor.getName() +
127                                           "' into \n" + (element != null ? element.getText() : "null psi element " + this.codegen.getContext().getContextDescriptor()) +
128                                           (generateNodeText ? ("\ncause: " + getNodeText(node)) : ""),
129                                           e, call.getCallElement());
130            }
131    
132    
133        }
134    
135        private void endCall(@NotNull InlineResult result) {
136            leaveTemps();
137    
138            state.getFactory().removeInlinedClasses(result.getClassesToRemove());
139        }
140    
141        @NotNull
142        private MethodNode createMethodNode(CallableMethod callableMethod)
143                throws ClassNotFoundException, IOException {
144            MethodNode node;
145            if (functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) {
146                VirtualFile file = InlineCodegenUtil.getVirtualFileForCallable((DeserializedSimpleFunctionDescriptor) functionDescriptor, state);
147                String methodDesc = callableMethod.getAsmMethod().getDescriptor();
148                DeclarationDescriptor parentDescriptor = functionDescriptor.getContainingDeclaration();
149                if (DescriptorUtils.isTrait(parentDescriptor)) {
150                    methodDesc = "(" + typeMapper.mapType((ClassDescriptor) parentDescriptor).getDescriptor() + methodDesc.substring(1);
151                }
152                node = InlineCodegenUtil.getMethodNode(file.getInputStream(), functionDescriptor.getName().asString(), methodDesc);
153    
154                if (node == null) {
155                    throw new RuntimeException("Couldn't obtain compiled function body for " + descriptorName(functionDescriptor));
156                }
157            }
158            else {
159                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, functionDescriptor);
160    
161                if (element == null) {
162                    throw new RuntimeException("Couldn't find declaration for function " + descriptorName(functionDescriptor));
163                }
164    
165                JvmMethodSignature jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
166                Method asmMethod = jvmSignature.getAsmMethod();
167                node = new MethodNode(InlineCodegenUtil.API,
168                                               getMethodAsmFlags(functionDescriptor, context.getContextKind()),
169                                               asmMethod.getName(),
170                                               asmMethod.getDescriptor(),
171                                               jvmSignature.getGenericsSignature(),
172                                               null);
173    
174                //for maxLocals calculation
175                MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(node);
176                FunctionCodegen.generateMethodBody(adapter, functionDescriptor, context.getParentContext().intoFunction(functionDescriptor),
177                                                   jvmSignature,
178                                                   new FunctionGenerationStrategy.FunctionDefault(state,
179                                                                                                  functionDescriptor,
180                                                                                                  (JetDeclarationWithBody) element),
181                                                   codegen.getParentCodegen());
182                adapter.visitMaxs(-1, -1);
183                adapter.visitEnd();
184            }
185            return node;
186        }
187    
188        private InlineResult inlineCall(MethodNode node) {
189            generateClosuresBodies();
190    
191            List<ParameterInfo> realParams = new ArrayList<ParameterInfo>(actualParameters);
192    
193            putClosureParametersOnStack();
194    
195            List<CapturedParamInfo> captured = getAllCaptured();
196    
197            Parameters parameters = new Parameters(realParams, Parameters.shiftAndAddStubs(captured, realParams.size()));
198    
199            InliningContext info = new RootInliningContext(expressionMap,
200                                                           state,
201                                                           codegen.getInlineNameGenerator()
202                                                                   .subGenerator(functionDescriptor.getName().asString()),
203                                                           codegen.getContext(),
204                                                           call,
205                                                           codegen.getParentCodegen().getClassName());
206    
207            MethodInliner inliner = new MethodInliner(node, parameters, info, new FieldRemapper(null, null, parameters), isSameModule, "Method inlining " + call.getCallElement().getText()); //with captured
208    
209            LocalVarRemapper remapper = new LocalVarRemapper(parameters, initialFrameSize);
210    
211            return inliner.doInline(codegen.v, remapper);
212        }
213    
214        private void generateClosuresBodies() {
215            for (LambdaInfo info : expressionMap.values()) {
216                info.setNode(generateLambdaBody(info));
217            }
218        }
219    
220        private MethodNode generateLambdaBody(LambdaInfo info) {
221            JetFunctionLiteral declaration = info.getFunctionLiteral();
222            FunctionDescriptor descriptor = info.getFunctionDescriptor();
223    
224            MethodContext parentContext = codegen.getContext();
225    
226            MethodContext context = parentContext.intoClosure(descriptor, codegen, typeMapper).intoInlinedLambda(descriptor);
227    
228            JvmMethodSignature jvmMethodSignature = typeMapper.mapSignature(descriptor);
229            Method asmMethod = jvmMethodSignature.getAsmMethod();
230            MethodNode methodNode = new MethodNode(InlineCodegenUtil.API, getMethodAsmFlags(descriptor, context.getContextKind()), asmMethod.getName(), asmMethod.getDescriptor(), jvmMethodSignature.getGenericsSignature(), null);
231    
232            MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(methodNode);
233    
234            FunctionCodegen.generateMethodBody(adapter, descriptor, context, jvmMethodSignature, new FunctionGenerationStrategy.FunctionDefault(state, descriptor, declaration), codegen.getParentCodegen());
235            adapter.visitMaxs(-1, -1);
236    
237            return methodNode;
238        }
239    
240    
241    
242        @Override
243        public void afterParameterPut(@NotNull Type type, @Nullable StackValue stackValue, ValueParameterDescriptor valueParameterDescriptor) {
244            putCapturedInLocal(type, stackValue, valueParameterDescriptor, -1);
245        }
246    
247        public void putCapturedInLocal(
248                @NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor, int capturedParamIndex
249        ) {
250            if (!asFunctionInline && Type.VOID_TYPE != type) {
251                //TODO remap only inlinable closure => otherwise we could get a lot of problem
252                boolean couldBeRemapped = !shouldPutValue(type, stackValue, valueParameterDescriptor);
253                StackValue remappedIndex = couldBeRemapped ? stackValue : null;
254    
255                ParameterInfo info = new ParameterInfo(type, false, couldBeRemapped ? -1 : codegen.getFrameMap().enterTemp(type), remappedIndex);
256    
257                if (capturedParamIndex >= 0 && couldBeRemapped) {
258                    CapturedParamInfo capturedParamInfo = activeLambda.getCapturedVars().get(capturedParamIndex);
259                    capturedParamInfo.setRemapValue(remappedIndex != null ? remappedIndex : StackValue.local(info.getIndex(), info.getType()));
260                }
261    
262                doWithParameter(info);
263            }
264        }
265    
266        /*descriptor is null for captured vars*/
267        public boolean shouldPutValue(
268                @NotNull Type type,
269                @Nullable StackValue stackValue,
270                @Nullable ValueParameterDescriptor descriptor
271        ) {
272    
273            if (stackValue == null) {
274                //default or vararg
275                return true;
276            }
277    
278            //remap only inline functions (and maybe non primitives)
279            //TODO - clean asserion and remapping logic
280            if (isPrimitive(type) != isPrimitive(stackValue.type)) {
281                //don't remap boxing/unboxing primitives - lost identity and perfomance
282                return true;
283            }
284    
285            if (stackValue instanceof StackValue.Local) {
286                return false;
287            }
288    
289            if (stackValue instanceof StackValue.Composed) {
290                //see: Method.isSpecialStackValue: go through aload 0
291                if (codegen.getContext().isInliningLambda() && codegen.getContext().getContextDescriptor() instanceof AnonymousFunctionDescriptor) {
292                    if (descriptor != null && !InlineUtil.hasNoinlineAnnotation(descriptor)) {
293                        //TODO: check type of context
294                        return false;
295                    }
296                }
297            }
298            return true;
299        }
300    
301        private void doWithParameter(ParameterInfo info) {
302            recordParamInfo(info, true);
303            putParameterOnStack(info);
304        }
305    
306        private int recordParamInfo(ParameterInfo info, boolean addToFrame) {
307            Type type = info.type;
308            actualParameters.add(info);
309            if (info.getType().getSize() == 2) {
310                actualParameters.add(ParameterInfo.STUB);
311            }
312            if (addToFrame) {
313                return originalFunctionFrame.enterTemp(type);
314            }
315            return -1;
316        }
317    
318        private void putParameterOnStack(ParameterInfo info) {
319            if (!info.isSkippedOrRemapped()) {
320                int index = info.getIndex();
321                Type type = info.type;
322                StackValue.local(index, type).store(type, codegen.v);
323            }
324        }
325    
326        @Override
327        public void putHiddenParams() {
328            List<JvmMethodParameterSignature> types = jvmSignature.getValueParameters();
329    
330            if (!isStaticMethod(functionDescriptor, context)) {
331                Type type = AsmTypeConstants.OBJECT_TYPE;
332                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
333                recordParamInfo(info, false);
334            }
335    
336            for (JvmMethodParameterSignature param : types) {
337                if (param.getKind() == JvmMethodParameterKind.VALUE) {
338                    break;
339                }
340                Type type = param.getAsmType();
341                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
342                recordParamInfo(info, false);
343            }
344    
345            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
346                ParameterInfo param = iterator.previous();
347                putParameterOnStack(param);
348            }
349        }
350    
351        public void leaveTemps() {
352            FrameMap frameMap = codegen.getFrameMap();
353            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
354                ParameterInfo param = iterator.previous();
355                if (!param.isSkippedOrRemapped()) {
356                    frameMap.leaveTemp(param.type);
357                }
358            }
359        }
360    
361        public static boolean isInliningClosure(JetExpression expression, ValueParameterDescriptor valueParameterDescriptora) {
362            //TODO deparenthisise
363            return expression instanceof JetFunctionLiteralExpression &&
364                   !InlineUtil.hasNoinlineAnnotation(valueParameterDescriptora);
365        }
366    
367        public void rememberClosure(JetFunctionLiteralExpression expression, Type type) {
368            ParameterInfo closureInfo = new ParameterInfo(type, true, -1, -1);
369            int index = recordParamInfo(closureInfo, true);
370    
371            LambdaInfo info = new LambdaInfo(expression, typeMapper);
372            expressionMap.put(index, info);
373    
374            closureInfo.setLambda(info);
375        }
376    
377        private void putClosureParametersOnStack() {
378            //TODO: SORT
379            int currentSize = actualParameters.size();
380            for (LambdaInfo next : expressionMap.values()) {
381                if (next.closure != null) {
382                    activeLambda = next;
383                    next.setParamOffset(currentSize);
384                    codegen.pushClosureOnStack(next.closure, false, this);
385                    currentSize += next.getCapturedVarsSize();
386                }
387            }
388            activeLambda = null;
389        }
390    
391        private List<CapturedParamInfo> getAllCaptured() {
392            //TODO: SORT
393            List<CapturedParamInfo> result = new ArrayList<CapturedParamInfo>();
394            for (LambdaInfo next : expressionMap.values()) {
395                if (next.closure != null) {
396                    result.addAll(next.getCapturedVars());
397                }
398            }
399            return result;
400        }
401    
402        public static CodegenContext getContext(DeclarationDescriptor descriptor, GenerationState state) {
403            if (descriptor instanceof PackageFragmentDescriptor) {
404                return new PackageContext((PackageFragmentDescriptor) descriptor, null, null);
405            }
406    
407            CodegenContext parent = getContext(descriptor.getContainingDeclaration(), state);
408    
409            if (descriptor instanceof ClassDescriptor) {
410                OwnerKind kind = DescriptorUtils.isTrait(descriptor) ? OwnerKind.TRAIT_IMPL : OwnerKind.IMPLEMENTATION;
411                return parent.intoClass((ClassDescriptor) descriptor, kind, state);
412            }
413            else if (descriptor instanceof FunctionDescriptor) {
414                return parent.intoFunction((FunctionDescriptor) descriptor);
415            }
416    
417            throw new IllegalStateException("Couldn't build context for " + descriptorName(descriptor));
418        }
419    
420        private static boolean isStaticMethod(FunctionDescriptor functionDescriptor, MethodContext context) {
421            return (getMethodAsmFlags(functionDescriptor, context.getContextKind()) & Opcodes.ACC_STATIC) != 0;
422        }
423    
424        @NotNull
425        public static String getNodeText(@Nullable MethodNode node) {
426            if (node == null) {
427                return "Not generated";
428            }
429            Textifier p = new Textifier();
430            node.accept(new TraceMethodVisitor(p));
431            StringWriter sw = new StringWriter();
432            p.print(new PrintWriter(sw));
433            sw.flush();
434            return node.name + " " + node.desc + ": \n " + sw.getBuffer().toString();
435        }
436    
437        private static String descriptorName(DeclarationDescriptor descriptor) {
438            return DescriptorRenderer.SHORT_NAMES_IN_TYPES.render(descriptor);
439        }
440    
441        @Override
442        public void genValueAndPut(
443                @NotNull ValueParameterDescriptor valueParameterDescriptor,
444                @NotNull JetExpression argumentExpression,
445                @NotNull Type parameterType
446        ) {
447            //TODO deparenthisise
448            if (isInliningClosure(argumentExpression, valueParameterDescriptor)) {
449                rememberClosure((JetFunctionLiteralExpression) argumentExpression, parameterType);
450            } else {
451                StackValue value = codegen.gen(argumentExpression);
452                if (shouldPutValue(parameterType, value, valueParameterDescriptor)) {
453                    value.put(parameterType, codegen.v);
454                }
455                afterParameterPut(parameterType, value, valueParameterDescriptor);
456            }
457        }
458    
459        @Override
460        public void putCapturedValueOnStack(
461                @NotNull StackValue stackValue, @NotNull Type valueType, int paramIndex
462        ) {
463            if (shouldPutValue(stackValue.type, stackValue, null)) {
464                stackValue.put(stackValue.type, codegen.v);
465            }
466            putCapturedInLocal(stackValue.type, stackValue, null, paramIndex);
467        }
468    }