Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
TransformLambdaSpecialization.hpp
Go to the documentation of this file.
1//===-- TransformLambdaSpecialization.hpp -- Specialize arguments --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_TRANSFORM_LAMBDA_SPECIALIZATION_HPP
12#define PROTEUS_TRANSFORM_LAMBDA_SPECIALIZATION_HPP
13
14#include <llvm/Demangle/Demangle.h>
15#include <llvm/IR/IRBuilder.h>
16#include <llvm/IR/Instructions.h>
17#include <llvm/Support/Casting.h>
18#include <llvm/Support/Debug.h>
19
21#include "proteus/Debug.h"
22#include "proteus/Utils.h"
23
24namespace proteus {
25
26using namespace llvm;
27
28inline Constant *getConstant(LLVMContext &Ctx, Type *ArgType,
29 const RuntimeConstant &RC) {
30 if (ArgType->isIntegerTy(1)) {
31 return ConstantInt::get(ArgType, RC.Value.BoolVal);
32 } else if (ArgType->isIntegerTy(8)) {
33 return ConstantInt::get(ArgType, RC.Value.Int8Val);
34 } else if (ArgType->isIntegerTy(32)) {
35 return ConstantInt::get(ArgType, RC.Value.Int32Val);
36 } else if (ArgType->isIntegerTy(64)) {
37 return ConstantInt::get(ArgType, RC.Value.Int64Val);
38 } else if (ArgType->isFloatTy()) {
39 return ConstantFP::get(ArgType, RC.Value.FloatVal);
40 } else if (ArgType->isDoubleTy()) {
41 return ConstantFP::get(ArgType, RC.Value.DoubleVal);
42 } else if (ArgType->isX86_FP80Ty() || ArgType->isPPC_FP128Ty() ||
43 ArgType->isFP128Ty()) {
44 return ConstantFP::get(ArgType, RC.Value.LongDoubleVal);
45 } else if (ArgType->isPointerTy()) {
46 auto *IntC = ConstantInt::get(Type::getInt64Ty(Ctx), RC.Value.Int64Val);
47 return ConstantExpr::getIntToPtr(IntC, ArgType);
48 } else {
49 std::string TypeString;
50 raw_string_ostream TypeOstream(TypeString);
51 ArgType->print(TypeOstream);
52 PROTEUS_FATAL_ERROR("JIT Incompatible type in runtime constant: " +
53 TypeOstream.str());
54 }
55}
56
58public:
59 static void transform(Module &M, Function &F,
60 const SmallVector<RuntimeConstant> &RCVec) {
61 auto *LambdaClass = F.getArg(0);
62 PROTEUS_DBG(Logger::logs("proteus")
63 << "[LambdaSpec] Function: " << F.getName() << " RCVec size "
64 << RCVec.size() << "\n");
65 PROTEUS_DBG(Logger::logs("proteus")
66 << "TransformLambdaSpecialization::transform" << "\n");
67 PROTEUS_DBG(Logger::logs("proteus") << "\t args" << "\n");
68#if PROTEUS_ENABLE_DEBUG
69 for (auto &Arg : RCVec) {
70 Logger::logs("proteus")
71 << "{" << Arg.Value.Int64Val << ", " << Arg.Slot << " }\n";
72 }
73#endif
74
75 auto TraceOut = [](int Slot, Constant *C) {
76 SmallString<128> S;
77 raw_svector_ostream OS(S);
78 OS << "[LambdaSpec] Replacing slot " << Slot << " with " << *C << "\n";
79
80 return S;
81 };
82
83 PROTEUS_DBG(Logger::logs("proteus") << "\t users" << "\n");
84 for (User *User : LambdaClass->users()) {
85 PROTEUS_DBG(Logger::logs("proteus") << *User << "\n");
86 if (isa<LoadInst>(User)) {
87 for (auto &Arg : RCVec) {
88 if (Arg.Slot == 0) {
89 Constant *C = getConstant(M.getContext(), User->getType(), Arg);
90 User->replaceAllUsesWith(C);
91 PROTEUS_DBG(Logger::logs("proteus") << TraceOut(Arg.Slot, C));
92 if (Config::get().ProteusTraceOutput)
93 Logger::trace(TraceOut(Arg.Slot, C));
94 }
95 }
96 } else if (auto *GEP = dyn_cast<GetElementPtrInst>(User)) {
97 auto *GEPSlot = GEP->getOperand(User->getNumOperands() - 1);
98 ConstantInt *CI = dyn_cast<ConstantInt>(GEPSlot);
99 int Slot = CI->getZExtValue();
100 for (auto &Arg : RCVec) {
101 if (Arg.Slot == Slot) {
102 for (auto *GEPUser : GEP->users()) {
103 auto *LI = dyn_cast<LoadInst>(GEPUser);
104 if (!LI)
105 PROTEUS_FATAL_ERROR("Expected load instruction");
106 Type *LoadType = LI->getType();
107 Constant *C = getConstant(M.getContext(), LoadType, Arg);
108 LI->replaceAllUsesWith(C);
109 PROTEUS_DBG(Logger::logs("proteus") << TraceOut(Arg.Slot, C));
110 if (Config::get().ProteusTraceOutput)
111 Logger::trace(TraceOut(Arg.Slot, C));
112 }
113 }
114 }
115 }
116 }
117 }
118};
119
120} // namespace proteus
121
122#endif
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
static Config & get()
Definition Config.hpp:112
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition TransformLambdaSpecialization.hpp:57
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:59
Definition Dispatcher.cpp:14
Constant * getConstant(LLVMContext &Ctx, Type *ArgType, const RuntimeConstant &RC)
Definition TransformLambdaSpecialization.hpp:28
Definition CompilerInterfaceTypes.h:38
RuntimeConstantValue Value
Definition CompilerInterfaceTypes.h:51