001    /*
002    * Copyright 2010-2013 JetBrains s.r.o.
003    *
004    * Licensed under the Apache License, Version 2.0 (the "License");
005    * you may not use this file except in compliance with the License.
006    * You may obtain a copy of the License at
007    *
008    * http://www.apache.org/licenses/LICENSE-2.0
009    *
010    * Unless required by applicable law or agreed to in writing, software
011    * distributed under the License is distributed on an "AS IS" BASIS,
012    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013    * See the License for the specific language governing permissions and
014    * limitations under the License.
015    */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
019    import com.intellij.openapi.vfs.VirtualFile;
020    import com.intellij.psi.PsiElement;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.annotations.Nullable;
023    import org.jetbrains.jet.codegen.*;
024    import org.jetbrains.jet.codegen.context.CodegenContext;
025    import org.jetbrains.jet.codegen.context.MethodContext;
026    import org.jetbrains.jet.codegen.context.PackageContext;
027    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodParameterKind;
028    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodParameterSignature;
029    import org.jetbrains.jet.lang.resolve.java.jvmSignature.JvmMethodSignature;
030    import org.jetbrains.jet.codegen.state.GenerationState;
031    import org.jetbrains.jet.codegen.state.JetTypeMapper;
032    import org.jetbrains.jet.descriptors.serialization.descriptors.DeserializedSimpleFunctionDescriptor;
033    import org.jetbrains.jet.lang.descriptors.*;
034    import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor;
035    import org.jetbrains.jet.lang.psi.*;
036    import org.jetbrains.jet.lang.resolve.BindingContext;
037    import org.jetbrains.jet.lang.resolve.BindingContextUtils;
038    import org.jetbrains.jet.lang.resolve.DescriptorUtils;
039    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall;
040    import org.jetbrains.jet.lang.resolve.java.AsmTypeConstants;
041    import org.jetbrains.jet.lang.types.lang.InlineStrategy;
042    import org.jetbrains.jet.lang.types.lang.InlineUtil;
043    import org.jetbrains.jet.renderer.DescriptorRenderer;
044    import org.jetbrains.org.objectweb.asm.MethodVisitor;
045    import org.jetbrains.org.objectweb.asm.Opcodes;
046    import org.jetbrains.org.objectweb.asm.Type;
047    import org.jetbrains.org.objectweb.asm.commons.Method;
048    import org.jetbrains.org.objectweb.asm.tree.MethodNode;
049    import org.jetbrains.org.objectweb.asm.util.Textifier;
050    import org.jetbrains.org.objectweb.asm.util.TraceMethodVisitor;
051    
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;
056    
057    import static org.jetbrains.jet.codegen.AsmUtil.getMethodAsmFlags;
058    import static org.jetbrains.jet.codegen.AsmUtil.isPrimitive;
059    import static org.jetbrains.jet.codegen.AsmUtil.isStatic;
060    
061    public class InlineCodegen implements CallGenerator {
062        private final GenerationState state;
063        private final JetTypeMapper typeMapper;
064        private final BindingContext bindingContext;
065    
066        private final SimpleFunctionDescriptor functionDescriptor;
067        private final JvmMethodSignature jvmSignature;
068        private final JetElement callElement;
069        private final MethodContext context;
070        private final ExpressionCodegen codegen;
071        private final FrameMap originalFunctionFrame;
072        private final boolean asFunctionInline;
073        private final int initialFrameSize;
074        private final boolean isSameModule;
075    
076        protected final List<ParameterInfo> actualParameters = new ArrayList<ParameterInfo>();
077        protected final Map<Integer, LambdaInfo> expressionMap = new HashMap<Integer, LambdaInfo>();
078    
079        private LambdaInfo activeLambda;
080    
081        public InlineCodegen(
082                @NotNull ExpressionCodegen codegen,
083                @NotNull GenerationState state,
084                @NotNull SimpleFunctionDescriptor functionDescriptor,
085                @NotNull JetElement callElement
086        ) {
087            assert functionDescriptor.getInlineStrategy().isInline() : "InlineCodegen could inline only inline function but " + functionDescriptor;
088    
089            this.state = state;
090            this.typeMapper = state.getTypeMapper();
091            this.codegen = codegen;
092            this.callElement = callElement;
093            this.functionDescriptor = functionDescriptor.getOriginal();
094            bindingContext = codegen.getBindingContext();
095            initialFrameSize = codegen.getFrameMap().getCurrentSize();
096    
097            context = (MethodContext) getContext(functionDescriptor, state);
098            originalFunctionFrame = context.prepareFrame(typeMapper);
099            jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
100    
101            InlineStrategy inlineStrategy =
102                    codegen.getContext().isInlineFunction() ? InlineStrategy.IN_PLACE : functionDescriptor.getInlineStrategy();
103            this.asFunctionInline = false;
104    
105            isSameModule = !(functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) /*not compiled library*/ &&
106                           JvmCodegenUtil.isCallInsideSameModuleAsDeclared(functionDescriptor, codegen.getContext());
107        }
108    
109        @Override
110        public void genCallWithoutAssertions(
111                @NotNull CallableMethod callableMethod, @NotNull ExpressionCodegen codegen
112        ) {
113            genCall(callableMethod, null, false, codegen);
114        }
115    
116        @Override
117        public void genCall(@NotNull CallableMethod callableMethod, @Nullable ResolvedCall<?> resolvedCall, boolean callDefault, @NotNull ExpressionCodegen codegen) {
118            MethodNode node = null;
119    
120            try {
121                node = createMethodNode(callDefault);
122                endCall(inlineCall(node));
123            }
124            catch (CompilationException e) {
125                throw e;
126            }
127            catch (Exception e) {
128                boolean generateNodeText = !(e instanceof InlineException);
129                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, this.codegen.getContext().getContextDescriptor());
130                throw new CompilationException("Couldn't inline method call '" +
131                                           functionDescriptor.getName() +
132                                           "' into \n" + (element != null ? element.getText() : "null psi element " + this.codegen.getContext().getContextDescriptor()) +
133                                           (generateNodeText ? ("\ncause: " + getNodeText(node)) : ""),
134                                           e, callElement);
135            }
136    
137    
138        }
139    
140        private void endCall(@NotNull InlineResult result) {
141            leaveTemps();
142    
143            state.getFactory().removeInlinedClasses(result.getClassesToRemove());
144        }
145    
146        @NotNull
147        private MethodNode createMethodNode(boolean callDefault) throws ClassNotFoundException, IOException {
148            JvmMethodSignature jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
149    
150            Method asmMethod;
151            if (callDefault) {
152                asmMethod = typeMapper.mapDefaultMethod(functionDescriptor, context.getContextKind(), context);
153            }
154            else {
155                asmMethod = jvmSignature.getAsmMethod();
156            }
157    
158            MethodNode node;
159            if (functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) {
160                VirtualFile file = InlineCodegenUtil.getVirtualFileForCallable((DeserializedSimpleFunctionDescriptor) functionDescriptor, state);
161                node = InlineCodegenUtil.getMethodNode(file.getInputStream(), asmMethod.getName(), asmMethod.getDescriptor());
162    
163                if (node == null) {
164                    throw new RuntimeException("Couldn't obtain compiled function body for " + descriptorName(functionDescriptor));
165                }
166            }
167            else {
168                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, functionDescriptor);
169    
170                if (element == null) {
171                    throw new RuntimeException("Couldn't find declaration for function " + descriptorName(functionDescriptor));
172                }
173    
174                node = new MethodNode(InlineCodegenUtil.API,
175                                               getMethodAsmFlags(functionDescriptor, context.getContextKind()) | (callDefault ? Opcodes.ACC_STATIC : 0),
176                                               asmMethod.getName(),
177                                               asmMethod.getDescriptor(),
178                                               jvmSignature.getGenericsSignature(),
179                                               null);
180    
181                //for maxLocals calculation
182                MethodVisitor maxCalcAdapter = InlineCodegenUtil.wrapWithMaxLocalCalc(node);
183                MethodContext methodContext = context.getParentContext().intoFunction(functionDescriptor);
184                MemberCodegen<?> parentCodegen = codegen.getParentCodegen();
185                if (callDefault) {
186                    boolean isStatic = isStatic(codegen.getContext().getContextKind());
187                    FunctionCodegen.generateDefaultImplBody(
188                            methodContext, jvmSignature, functionDescriptor, isStatic, maxCalcAdapter, DefaultParameterValueLoader.DEFAULT,
189                            (JetNamedFunction) element, parentCodegen, state
190                    );
191                }
192                else {
193                    FunctionCodegen.generateMethodBody(
194                            maxCalcAdapter, functionDescriptor, methodContext, jvmSignature,
195                            new FunctionGenerationStrategy.FunctionDefault(state, functionDescriptor, (JetDeclarationWithBody) element),
196                            parentCodegen
197                    );
198                }
199                maxCalcAdapter.visitMaxs(-1, -1);
200                maxCalcAdapter.visitEnd();
201            }
202            return node;
203        }
204    
205        private InlineResult inlineCall(MethodNode node) {
206            generateClosuresBodies();
207    
208            List<ParameterInfo> realParams = new ArrayList<ParameterInfo>(actualParameters);
209    
210            putClosureParametersOnStack();
211    
212            List<CapturedParamInfo> captured = getAllCaptured();
213    
214            Parameters parameters = new Parameters(realParams, Parameters.shiftAndAddStubs(captured, realParams.size()));
215    
216            InliningContext info = new RootInliningContext(expressionMap,
217                                                           state,
218                                                           codegen.getInlineNameGenerator()
219                                                                   .subGenerator(functionDescriptor.getName().asString()),
220                                                           codegen.getContext(),
221                                                           callElement,
222                                                           codegen.getParentCodegen().getClassName());
223    
224            MethodInliner inliner = new MethodInliner(node, parameters, info, new FieldRemapper(null, null, parameters), isSameModule, "Method inlining " + callElement.getText()); //with captured
225    
226            LocalVarRemapper remapper = new LocalVarRemapper(parameters, initialFrameSize);
227    
228            return inliner.doInline(codegen.v, remapper);
229        }
230    
231        private void generateClosuresBodies() {
232            for (LambdaInfo info : expressionMap.values()) {
233                info.setNode(generateLambdaBody(info));
234            }
235        }
236    
237        private MethodNode generateLambdaBody(LambdaInfo info) {
238            JetFunctionLiteral declaration = info.getFunctionLiteral();
239            FunctionDescriptor descriptor = info.getFunctionDescriptor();
240    
241            MethodContext parentContext = codegen.getContext();
242    
243            MethodContext context = parentContext.intoClosure(descriptor, codegen, typeMapper).intoInlinedLambda(descriptor);
244    
245            JvmMethodSignature jvmMethodSignature = typeMapper.mapSignature(descriptor);
246            Method asmMethod = jvmMethodSignature.getAsmMethod();
247            MethodNode methodNode = new MethodNode(InlineCodegenUtil.API, getMethodAsmFlags(descriptor, context.getContextKind()), asmMethod.getName(), asmMethod.getDescriptor(), jvmMethodSignature.getGenericsSignature(), null);
248    
249            MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(methodNode);
250    
251            FunctionCodegen.generateMethodBody(adapter, descriptor, context, jvmMethodSignature, new FunctionGenerationStrategy.FunctionDefault(state, descriptor, declaration), codegen.getParentCodegen());
252            adapter.visitMaxs(-1, -1);
253    
254            return methodNode;
255        }
256    
257    
258    
259        @Override
260        public void afterParameterPut(@NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor) {
261            putCapturedInLocal(type, stackValue, valueParameterDescriptor, -1);
262        }
263    
264        private void putCapturedInLocal(
265                @NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor, int capturedParamIndex
266        ) {
267            if (!asFunctionInline && Type.VOID_TYPE != type) {
268                //TODO remap only inlinable closure => otherwise we could get a lot of problem
269                boolean couldBeRemapped = !shouldPutValue(type, stackValue, valueParameterDescriptor);
270                StackValue remappedIndex = couldBeRemapped ? stackValue : null;
271    
272                ParameterInfo info = new ParameterInfo(type, false, couldBeRemapped ? -1 : codegen.getFrameMap().enterTemp(type), remappedIndex);
273    
274                if (capturedParamIndex >= 0 && couldBeRemapped) {
275                    CapturedParamInfo capturedParamInfo = activeLambda.getCapturedVars().get(capturedParamIndex);
276                    capturedParamInfo.setRemapValue(remappedIndex != null ? remappedIndex : StackValue.local(info.getIndex(), info.getType()));
277                }
278    
279                doWithParameter(info);
280            }
281        }
282    
283        /*descriptor is null for captured vars*/
284        public boolean shouldPutValue(
285                @NotNull Type type,
286                @Nullable StackValue stackValue,
287                @Nullable ValueParameterDescriptor descriptor
288        ) {
289    
290            if (stackValue == null) {
291                //default or vararg
292                return true;
293            }
294    
295            //remap only inline functions (and maybe non primitives)
296            //TODO - clean asserion and remapping logic
297            if (isPrimitive(type) != isPrimitive(stackValue.type)) {
298                //don't remap boxing/unboxing primitives - lost identity and perfomance
299                return true;
300            }
301    
302            if (stackValue instanceof StackValue.Local) {
303                return false;
304            }
305    
306            if (stackValue instanceof StackValue.Composed) {
307                //see: Method.isSpecialStackValue: go through aload 0
308                if (codegen.getContext().isInliningLambda() && codegen.getContext().getContextDescriptor() instanceof AnonymousFunctionDescriptor) {
309                    if (descriptor != null && !InlineUtil.hasNoinlineAnnotation(descriptor)) {
310                        //TODO: check type of context
311                        return false;
312                    }
313                }
314            }
315            return true;
316        }
317    
318        private void doWithParameter(ParameterInfo info) {
319            recordParamInfo(info, true);
320            putParameterOnStack(info);
321        }
322    
323        private int recordParamInfo(ParameterInfo info, boolean addToFrame) {
324            Type type = info.type;
325            actualParameters.add(info);
326            if (info.getType().getSize() == 2) {
327                actualParameters.add(ParameterInfo.STUB);
328            }
329            if (addToFrame) {
330                return originalFunctionFrame.enterTemp(type);
331            }
332            return -1;
333        }
334    
335        private void putParameterOnStack(ParameterInfo info) {
336            if (!info.isSkippedOrRemapped()) {
337                int index = info.getIndex();
338                Type type = info.type;
339                StackValue.local(index, type).store(type, codegen.v);
340            }
341        }
342    
343        @Override
344        public void putHiddenParams() {
345            List<JvmMethodParameterSignature> types = jvmSignature.getValueParameters();
346    
347            if (!isStaticMethod(functionDescriptor, context)) {
348                Type type = AsmTypeConstants.OBJECT_TYPE;
349                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
350                recordParamInfo(info, false);
351            }
352    
353            for (JvmMethodParameterSignature param : types) {
354                if (param.getKind() == JvmMethodParameterKind.VALUE) {
355                    break;
356                }
357                Type type = param.getAsmType();
358                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
359                recordParamInfo(info, false);
360            }
361    
362            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
363                ParameterInfo param = iterator.previous();
364                putParameterOnStack(param);
365            }
366        }
367    
368        public void leaveTemps() {
369            FrameMap frameMap = codegen.getFrameMap();
370            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
371                ParameterInfo param = iterator.previous();
372                if (!param.isSkippedOrRemapped()) {
373                    frameMap.leaveTemp(param.type);
374                }
375            }
376        }
377    
378        public static boolean isInliningClosure(JetExpression expression, ValueParameterDescriptor valueParameterDescriptora) {
379            //TODO deparenthisise
380            return expression instanceof JetFunctionLiteralExpression &&
381                   !InlineUtil.hasNoinlineAnnotation(valueParameterDescriptora);
382        }
383    
384        public void rememberClosure(JetFunctionLiteralExpression expression, Type type) {
385            ParameterInfo closureInfo = new ParameterInfo(type, true, -1, -1);
386            int index = recordParamInfo(closureInfo, true);
387    
388            LambdaInfo info = new LambdaInfo(expression, typeMapper);
389            expressionMap.put(index, info);
390    
391            closureInfo.setLambda(info);
392        }
393    
394        private void putClosureParametersOnStack() {
395            //TODO: SORT
396            int currentSize = actualParameters.size();
397            for (LambdaInfo next : expressionMap.values()) {
398                if (next.closure != null) {
399                    activeLambda = next;
400                    next.setParamOffset(currentSize);
401                    codegen.pushClosureOnStack(next.closure, false, this);
402                    currentSize += next.getCapturedVarsSize();
403                }
404            }
405            activeLambda = null;
406        }
407    
408        private List<CapturedParamInfo> getAllCaptured() {
409            //TODO: SORT
410            List<CapturedParamInfo> result = new ArrayList<CapturedParamInfo>();
411            for (LambdaInfo next : expressionMap.values()) {
412                if (next.closure != null) {
413                    result.addAll(next.getCapturedVars());
414                }
415            }
416            return result;
417        }
418    
419        public static CodegenContext getContext(DeclarationDescriptor descriptor, GenerationState state) {
420            if (descriptor instanceof PackageFragmentDescriptor) {
421                return new PackageContext((PackageFragmentDescriptor) descriptor, null, null);
422            }
423    
424            CodegenContext parent = getContext(descriptor.getContainingDeclaration(), state);
425    
426            if (descriptor instanceof ClassDescriptor) {
427                OwnerKind kind = DescriptorUtils.isTrait(descriptor) ? OwnerKind.TRAIT_IMPL : OwnerKind.IMPLEMENTATION;
428                return parent.intoClass((ClassDescriptor) descriptor, kind, state);
429            }
430            else if (descriptor instanceof FunctionDescriptor) {
431                return parent.intoFunction((FunctionDescriptor) descriptor);
432            }
433    
434            throw new IllegalStateException("Couldn't build context for " + descriptorName(descriptor));
435        }
436    
437        private static boolean isStaticMethod(FunctionDescriptor functionDescriptor, MethodContext context) {
438            return (getMethodAsmFlags(functionDescriptor, context.getContextKind()) & Opcodes.ACC_STATIC) != 0;
439        }
440    
441        @NotNull
442        public static String getNodeText(@Nullable MethodNode node) {
443            if (node == null) {
444                return "Not generated";
445            }
446            Textifier p = new Textifier();
447            node.accept(new TraceMethodVisitor(p));
448            StringWriter sw = new StringWriter();
449            p.print(new PrintWriter(sw));
450            sw.flush();
451            return node.name + " " + node.desc + ": \n " + sw.getBuffer().toString();
452        }
453    
454        private static String descriptorName(DeclarationDescriptor descriptor) {
455            return DescriptorRenderer.SHORT_NAMES_IN_TYPES.render(descriptor);
456        }
457    
458        @Override
459        public void genValueAndPut(
460                @NotNull ValueParameterDescriptor valueParameterDescriptor,
461                @NotNull JetExpression argumentExpression,
462                @NotNull Type parameterType
463        ) {
464            //TODO deparenthisise
465            if (isInliningClosure(argumentExpression, valueParameterDescriptor)) {
466                rememberClosure((JetFunctionLiteralExpression) argumentExpression, parameterType);
467            } else {
468                StackValue value = codegen.gen(argumentExpression);
469                putValueIfNeeded(valueParameterDescriptor, parameterType, value);
470            }
471        }
472    
473        @Override
474        public void putValueIfNeeded(@Nullable ValueParameterDescriptor valueParameterDescriptor, @NotNull Type parameterType, @NotNull StackValue value) {
475            if (shouldPutValue(parameterType, value, valueParameterDescriptor)) {
476                value.put(parameterType, codegen.v);
477            }
478            afterParameterPut(parameterType, value, valueParameterDescriptor);
479        }
480    
481        @Override
482        public void putCapturedValueOnStack(
483                @NotNull StackValue stackValue, @NotNull Type valueType, int paramIndex
484        ) {
485            if (shouldPutValue(stackValue.type, stackValue, null)) {
486                stackValue.put(stackValue.type, codegen.v);
487            }
488            putCapturedInLocal(stackValue.type, stackValue, null, paramIndex);
489        }
490    }