001    /*
002    * Copyright 2010-2013 JetBrains s.r.o.
003    *
004    * Licensed under the Apache License, Version 2.0 (the "License");
005    * you may not use this file except in compliance with the License.
006    * You may obtain a copy of the License at
007    *
008    * http://www.apache.org/licenses/LICENSE-2.0
009    *
010    * Unless required by applicable law or agreed to in writing, software
011    * distributed under the License is distributed on an "AS IS" BASIS,
012    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013    * See the License for the specific language governing permissions and
014    * limitations under the License.
015    */
016    
017    package org.jetbrains.jet.codegen.inline;
018    
019    import com.intellij.openapi.vfs.VirtualFile;
020    import com.intellij.psi.PsiElement;
021    import org.jetbrains.annotations.NotNull;
022    import org.jetbrains.annotations.Nullable;
023    import org.jetbrains.asm4.MethodVisitor;
024    import org.jetbrains.asm4.Opcodes;
025    import org.jetbrains.asm4.Type;
026    import org.jetbrains.asm4.commons.Method;
027    import org.jetbrains.asm4.tree.MethodNode;
028    import org.jetbrains.asm4.util.Textifier;
029    import org.jetbrains.asm4.util.TraceMethodVisitor;
030    import org.jetbrains.jet.codegen.*;
031    import org.jetbrains.jet.codegen.context.CodegenContext;
032    import org.jetbrains.jet.codegen.context.MethodContext;
033    import org.jetbrains.jet.codegen.context.PackageContext;
034    import org.jetbrains.jet.codegen.signature.JvmMethodParameterKind;
035    import org.jetbrains.jet.codegen.signature.JvmMethodParameterSignature;
036    import org.jetbrains.jet.codegen.signature.JvmMethodSignature;
037    import org.jetbrains.jet.codegen.state.GenerationState;
038    import org.jetbrains.jet.codegen.state.JetTypeMapper;
039    import org.jetbrains.jet.descriptors.serialization.descriptors.DeserializedSimpleFunctionDescriptor;
040    import org.jetbrains.jet.lang.descriptors.*;
041    import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor;
042    import org.jetbrains.jet.lang.psi.*;
043    import org.jetbrains.jet.lang.resolve.BindingContext;
044    import org.jetbrains.jet.lang.resolve.BindingContextUtils;
045    import org.jetbrains.jet.lang.resolve.DescriptorUtils;
046    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall;
047    import org.jetbrains.jet.lang.resolve.java.AsmTypeConstants;
048    import org.jetbrains.jet.lang.types.lang.InlineStrategy;
049    import org.jetbrains.jet.lang.types.lang.InlineUtil;
050    import org.jetbrains.jet.renderer.DescriptorRenderer;
051    
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;
056    
057    import static org.jetbrains.jet.codegen.AsmUtil.getMethodAsmFlags;
058    import static org.jetbrains.jet.codegen.AsmUtil.isPrimitive;
059    
060    public class InlineCodegen implements ParentCodegenAware, CallGenerator {
061    
062        private final JetTypeMapper typeMapper;
063    
064        private final ExpressionCodegen codegen;
065    
066        private final boolean asFunctionInline;
067    
068        private final GenerationState state;
069    
070        private final Call call;
071    
072        private final SimpleFunctionDescriptor functionDescriptor;
073    
074        private final BindingContext bindingContext;
075    
076        private final MethodContext context;
077    
078        private final FrameMap originalFunctionFrame;
079    
080        private final int initialFrameSize;
081    
082        private final JvmMethodSignature jvmSignature;
083    
084        private final boolean isSameModule;
085    
086        private LambdaInfo activeLambda;
087    
088        protected final List<ParameterInfo> actualParameters = new ArrayList<ParameterInfo>();
089    
090        protected final Map<Integer, LambdaInfo> expressionMap = new HashMap<Integer, LambdaInfo>();
091    
092        public InlineCodegen(
093                @NotNull ExpressionCodegen codegen,
094                @NotNull GenerationState state,
095                @NotNull SimpleFunctionDescriptor functionDescriptor,
096                @NotNull Call call
097        ) {
098            assert functionDescriptor.getInlineStrategy().isInline() : "InlineCodegen could inline only inline function but " + functionDescriptor;
099    
100            this.state = state;
101            this.typeMapper = state.getTypeMapper();
102            this.codegen = codegen;
103            this.call = call;
104            this.functionDescriptor = functionDescriptor.getOriginal();
105            bindingContext = codegen.getBindingContext();
106            initialFrameSize = codegen.getFrameMap().getCurrentSize();
107    
108            context = (MethodContext) getContext(functionDescriptor, state);
109            originalFunctionFrame = context.prepareFrame(typeMapper);
110            jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
111    
112            InlineStrategy inlineStrategy =
113                    codegen.getContext().isInlineFunction() ? InlineStrategy.IN_PLACE : functionDescriptor.getInlineStrategy();
114            this.asFunctionInline = false;
115    
116            isSameModule = !(functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) /*not compiled library*/ &&
117                           CodegenUtil.isCallInsideSameModuleAsDeclared(functionDescriptor, codegen.getContext());
118        }
119    
120    
121        @Override
122        public void genCall(CallableMethod callableMethod, ResolvedCall<?> resolvedCall, int mask, ExpressionCodegen codegen) {
123            assert mask == 0 : "Default method invocation couldn't be inlined " + resolvedCall;
124    
125            MethodNode node = null;
126    
127            try {
128                node = createMethodNode(callableMethod);
129                endCall(inlineCall(node));
130            }
131            catch (CompilationException e) {
132                throw e;
133            }
134            catch (Exception e) {
135                boolean generateNodeText = !(e instanceof InlineException);
136                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, this.codegen.getContext().getContextDescriptor());
137                throw new CompilationException("Couldn't inline method call '" +
138                                           functionDescriptor.getName() +
139                                           "' into \n" + (element != null ? element.getText() : "null psi element " + this.codegen.getContext().getContextDescriptor()) +
140                                           (generateNodeText ? ("\ncause: " + getNodeText(node)) : ""),
141                                           e, call.getCallElement());
142            }
143    
144    
145        }
146    
147        private void endCall(@NotNull InlineResult result) {
148            leaveTemps();
149    
150            state.getFactory().removeInlinedClasses(result.getClassesToRemove());
151        }
152    
153        @NotNull
154        private MethodNode createMethodNode(CallableMethod callableMethod)
155                throws ClassNotFoundException, IOException {
156            MethodNode node;
157            if (functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) {
158                VirtualFile file = InlineCodegenUtil.getVirtualFileForCallable((DeserializedSimpleFunctionDescriptor) functionDescriptor, state);
159                String methodDesc = callableMethod.getAsmMethod().getDescriptor();
160                DeclarationDescriptor parentDescriptor = functionDescriptor.getContainingDeclaration();
161                if (DescriptorUtils.isTrait(parentDescriptor)) {
162                    methodDesc = "(" + typeMapper.mapType((ClassDescriptor) parentDescriptor).getDescriptor() + methodDesc.substring(1);
163                }
164                node = InlineCodegenUtil.getMethodNode(file.getInputStream(), functionDescriptor.getName().asString(), methodDesc);
165    
166                if (node == null) {
167                    throw new RuntimeException("Couldn't obtain compiled function body for " + descriptorName(functionDescriptor));
168                }
169            }
170            else {
171                PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, functionDescriptor);
172    
173                if (element == null) {
174                    throw new RuntimeException("Couldn't find declaration for function " + descriptorName(functionDescriptor));
175                }
176    
177                JvmMethodSignature jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
178                Method asmMethod = jvmSignature.getAsmMethod();
179                node = new MethodNode(Opcodes.ASM4,
180                                               getMethodAsmFlags(functionDescriptor, context.getContextKind()),
181                                               asmMethod.getName(),
182                                               asmMethod.getDescriptor(),
183                                               jvmSignature.getGenericsSignature(),
184                                               null);
185    
186                //for maxLocals calculation
187                MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(node);
188                FunctionCodegen.generateMethodBody(adapter, functionDescriptor, context.getParentContext().intoFunction(functionDescriptor),
189                                                   jvmSignature,
190                                                   new FunctionGenerationStrategy.FunctionDefault(state,
191                                                                                                  functionDescriptor,
192                                                                                                  (JetDeclarationWithBody) element),
193                                                   getParentCodegen());
194                adapter.visitMaxs(-1, -1);
195                adapter.visitEnd();
196            }
197            return node;
198        }
199    
200        private InlineResult inlineCall(MethodNode node) {
201            generateClosuresBodies();
202    
203            List<ParameterInfo> realParams = new ArrayList<ParameterInfo>(actualParameters);
204    
205            putClosureParametersOnStack();
206    
207            List<CapturedParamInfo> captured = getAllCaptured();
208    
209            Parameters parameters = new Parameters(realParams, Parameters.shiftAndAddStubs(captured, realParams.size()));
210    
211            InliningContext info =
212                    new InliningContext(expressionMap, null, null, null, state,
213                                     codegen.getInlineNameGenerator().subGenerator(functionDescriptor.getName().asString()),
214                                     codegen.getContext(), call, Collections.<String, String>emptyMap(), false, false);
215    
216            MethodInliner inliner = new MethodInliner(node, parameters, info, new FieldRemapper(null, null, parameters), isSameModule, "Method inlining " + call.getCallElement().getText()); //with captured
217    
218            VarRemapper.ParamRemapper remapper = new VarRemapper.ParamRemapper(parameters, initialFrameSize);
219    
220            return inliner.doInline(codegen.v, remapper, new FieldRemapper(null, null, parameters));
221        }
222    
223        private void generateClosuresBodies() {
224            for (LambdaInfo info : expressionMap.values()) {
225                info.setNode(generateLambdaBody(info));
226            }
227        }
228    
229        private MethodNode generateLambdaBody(LambdaInfo info) {
230            JetFunctionLiteral declaration = info.getFunctionLiteral();
231            FunctionDescriptor descriptor = info.getFunctionDescriptor();
232    
233            MethodContext parentContext = codegen.getContext();
234    
235            MethodContext context = parentContext.intoClosure(descriptor, codegen, typeMapper).intoInlinedLambda(descriptor);
236    
237            JvmMethodSignature jvmMethodSignature = typeMapper.mapSignature(descriptor);
238            Method asmMethod = jvmMethodSignature.getAsmMethod();
239            MethodNode methodNode = new MethodNode(Opcodes.ASM4, getMethodAsmFlags(descriptor, context.getContextKind()), asmMethod.getName(), asmMethod.getDescriptor(), jvmMethodSignature.getGenericsSignature(), null);
240    
241            MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(methodNode);
242    
243            FunctionCodegen.generateMethodBody(adapter, descriptor, context, jvmMethodSignature, new FunctionGenerationStrategy.FunctionDefault(state, descriptor, declaration) {
244                @Override
245                public boolean generateLocalVarTable() {
246                    return false;
247                }
248            }, codegen.getParentCodegen());
249            adapter.visitMaxs(-1, -1);
250    
251            return methodNode;
252        }
253    
254    
255    
256        @Override
257        public void afterParameterPut(@NotNull Type type, @Nullable StackValue stackValue, ValueParameterDescriptor valueParameterDescriptor) {
258            putCapturedInLocal(type, stackValue, valueParameterDescriptor, -1);
259        }
260    
261        public void putCapturedInLocal(
262                @NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor, int capturedParamIndex
263        ) {
264            if (!asFunctionInline && Type.VOID_TYPE != type) {
265                //TODO remap only inlinable closure => otherwise we could get a lot of problem
266                boolean couldBeRemapped = !shouldPutValue(type, stackValue, valueParameterDescriptor);
267                StackValue remappedIndex = couldBeRemapped ? stackValue : null;
268    
269                ParameterInfo info = new ParameterInfo(type, false, couldBeRemapped ? -1 : codegen.getFrameMap().enterTemp(type), remappedIndex);
270    
271                if (capturedParamIndex >= 0 && couldBeRemapped) {
272                    CapturedParamInfo capturedParamInfo = activeLambda.getCapturedVars().get(capturedParamIndex);
273                    capturedParamInfo.setRemapValue(remappedIndex != null ? remappedIndex : StackValue.local(info.getIndex(), info.getType()));
274                }
275    
276                doWithParameter(info);
277            }
278        }
279    
280        /*descriptor is null for captured vars*/
281        public boolean shouldPutValue(
282                @NotNull Type type,
283                @Nullable StackValue stackValue,
284                @Nullable ValueParameterDescriptor descriptor
285        ) {
286    
287            if (stackValue == null) {
288                //default or vararg
289                return true;
290            }
291    
292            //remap only inline functions (and maybe non primitives)
293            //TODO - clean asserion and remapping logic
294            if (isPrimitive(type) != isPrimitive(stackValue.type)) {
295                //don't remap boxing/unboxing primitives - lost identity and perfomance
296                return true;
297            }
298    
299            if (stackValue instanceof StackValue.Local) {
300                return false;
301            }
302    
303            if (stackValue instanceof StackValue.Composed) {
304                //see: Method.isSpecialStackValue: go through aload 0
305                if (codegen.getContext().isInliningLambda() && codegen.getContext().getContextDescriptor() instanceof AnonymousFunctionDescriptor) {
306                    if (descriptor != null && !InlineUtil.hasNoinlineAnnotation(descriptor)) {
307                        //TODO: check type of context
308                        return false;
309                    }
310                }
311            }
312            return true;
313        }
314    
315        private void doWithParameter(ParameterInfo info) {
316            recordParamInfo(info, true);
317            putParameterOnStack(info);
318        }
319    
320        private int recordParamInfo(ParameterInfo info, boolean addToFrame) {
321            Type type = info.type;
322            actualParameters.add(info);
323            if (info.getType().getSize() == 2) {
324                actualParameters.add(ParameterInfo.STUB);
325            }
326            if (addToFrame) {
327                return originalFunctionFrame.enterTemp(type);
328            }
329            return -1;
330        }
331    
332        private void putParameterOnStack(ParameterInfo info) {
333            if (!info.isSkippedOrRemapped()) {
334                int index = info.getIndex();
335                Type type = info.type;
336                StackValue.local(index, type).store(type, codegen.v);
337            }
338        }
339    
340        @Override
341        public void putHiddenParams() {
342            List<JvmMethodParameterSignature> types = jvmSignature.getValueParameters();
343    
344            if (!isStaticMethod(functionDescriptor, context)) {
345                Type type = AsmTypeConstants.OBJECT_TYPE;
346                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
347                recordParamInfo(info, false);
348            }
349    
350            for (JvmMethodParameterSignature param : types) {
351                if (param.getKind() == JvmMethodParameterKind.VALUE) {
352                    break;
353                }
354                Type type = param.getAsmType();
355                ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
356                recordParamInfo(info, false);
357            }
358    
359            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
360                ParameterInfo param = iterator.previous();
361                putParameterOnStack(param);
362            }
363        }
364    
365        public void leaveTemps() {
366            FrameMap frameMap = codegen.getFrameMap();
367            for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
368                ParameterInfo param = iterator.previous();
369                if (!param.isSkippedOrRemapped()) {
370                    frameMap.leaveTemp(param.type);
371                }
372            }
373        }
374    
375        public static boolean isInliningClosure(JetExpression expression, ValueParameterDescriptor valueParameterDescriptora) {
376            //TODO deparenthisise
377            return expression instanceof JetFunctionLiteralExpression &&
378                   !InlineUtil.hasNoinlineAnnotation(valueParameterDescriptora);
379        }
380    
381        public void rememberClosure(JetFunctionLiteralExpression expression, Type type) {
382            ParameterInfo closureInfo = new ParameterInfo(type, true, -1, -1);
383            int index = recordParamInfo(closureInfo, true);
384    
385            LambdaInfo info = new LambdaInfo(expression, typeMapper);
386            expressionMap.put(index, info);
387    
388            closureInfo.setLambda(info);
389        }
390    
391        private void putClosureParametersOnStack() {
392            //TODO: SORT
393            int currentSize = actualParameters.size();
394            for (LambdaInfo next : expressionMap.values()) {
395                if (next.closure != null) {
396                    activeLambda = next;
397                    next.setParamOffset(currentSize);
398                    codegen.pushClosureOnStack(next.closure, false, this);
399                    currentSize += next.getCapturedVarsSize();
400                }
401            }
402            activeLambda = null;
403        }
404    
405        private List<CapturedParamInfo> getAllCaptured() {
406            //TODO: SORT
407            List<CapturedParamInfo> result = new ArrayList<CapturedParamInfo>();
408            for (LambdaInfo next : expressionMap.values()) {
409                if (next.closure != null) {
410                    result.addAll(next.getCapturedVars());
411                }
412            }
413            return result;
414        }
415    
416        @Nullable
417        @Override
418        public MemberCodegen getParentCodegen() {
419            return codegen.getParentCodegen();
420        }
421    
422        public static CodegenContext getContext(DeclarationDescriptor descriptor, GenerationState state) {
423            if (descriptor instanceof PackageFragmentDescriptor) {
424                return new PackageContext((PackageFragmentDescriptor) descriptor, null, null);
425            }
426    
427            CodegenContext parent = getContext(descriptor.getContainingDeclaration(), state);
428    
429            if (descriptor instanceof ClassDescriptor) {
430                OwnerKind kind = DescriptorUtils.isTrait(descriptor) ? OwnerKind.TRAIT_IMPL : OwnerKind.IMPLEMENTATION;
431                return parent.intoClass((ClassDescriptor) descriptor, kind, state);
432            }
433            else if (descriptor instanceof FunctionDescriptor) {
434                return parent.intoFunction((FunctionDescriptor) descriptor);
435            }
436    
437            throw new IllegalStateException("Couldn't build context for " + descriptorName(descriptor));
438        }
439    
440        private static boolean isStaticMethod(FunctionDescriptor functionDescriptor, MethodContext context) {
441            return (getMethodAsmFlags(functionDescriptor, context.getContextKind()) & Opcodes.ACC_STATIC) != 0;
442        }
443    
444        @NotNull
445        public static String getNodeText(@Nullable MethodNode node) {
446            if (node == null) {
447                return "Not generated";
448            }
449            Textifier p = new Textifier();
450            node.accept(new TraceMethodVisitor(p));
451            StringWriter sw = new StringWriter();
452            p.print(new PrintWriter(sw));
453            sw.flush();
454            return node.name + " " + node.desc + ": \n " + sw.getBuffer().toString();
455        }
456    
457        private static String descriptorName(DeclarationDescriptor descriptor) {
458            return DescriptorRenderer.SHORT_NAMES_IN_TYPES.render(descriptor);
459        }
460    
461        @Override
462        public void genValueAndPut(
463                @NotNull ValueParameterDescriptor valueParameterDescriptor,
464                @NotNull JetExpression argumentExpression,
465                @NotNull Type parameterType
466        ) {
467            //TODO deparenthisise
468            if (isInliningClosure(argumentExpression, valueParameterDescriptor)) {
469                rememberClosure((JetFunctionLiteralExpression) argumentExpression, parameterType);
470            } else {
471                StackValue value = codegen.gen(argumentExpression);
472                if (shouldPutValue(parameterType, value, valueParameterDescriptor)) {
473                    value.put(parameterType, codegen.v);
474                }
475                afterParameterPut(parameterType, value, valueParameterDescriptor);
476            }
477        }
478    
479        @Override
480        public void putCapturedValueOnStack(
481                @NotNull StackValue stackValue, @NotNull Type valueType, int paramIndex
482        ) {
483            if (shouldPutValue(stackValue.type, stackValue, null)) {
484                stackValue.put(stackValue.type, codegen.v);
485            }
486            putCapturedInLocal(stackValue.type, stackValue, null, paramIndex);
487        }
488    }