Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
TransformLambdaSpecialization.hpp
Go to the documentation of this file.
1//===-- TransformLambdaSpecialization.hpp -- Specialize arguments --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_TRANSFORM_LAMBDA_SPECIALIZATION_HPP
12#define PROTEUS_TRANSFORM_LAMBDA_SPECIALIZATION_HPP
13
14#include <llvm/Demangle/Demangle.h>
15#include <llvm/IR/IRBuilder.h>
16#include <llvm/IR/Instructions.h>
17#include <llvm/Support/Casting.h>
18#include <llvm/Support/Debug.h>
19
21#include "proteus/Debug.h"
22#include "proteus/Utils.h"
23
24namespace proteus {
25
26using namespace llvm;
27
28inline Constant *getConstant(LLVMContext &Ctx, Type *ArgType,
29 const RuntimeConstant &RC) {
30 switch (RC.Type) {
32 return ConstantInt::get(ArgType, RC.Value.BoolVal);
34 return ConstantInt::get(ArgType, RC.Value.Int8Val);
36 return ConstantInt::get(ArgType, RC.Value.Int32Val);
38 return ConstantInt::get(ArgType, RC.Value.Int64Val);
40 return ConstantFP::get(ArgType, RC.Value.FloatVal);
42 return ConstantFP::get(ArgType, RC.Value.DoubleVal);
44 return ConstantFP::get(ArgType, RC.Value.LongDoubleVal);
46 auto *IntC = ConstantInt::get(Type::getInt64Ty(Ctx), RC.Value.Int64Val);
47 return ConstantExpr::getIntToPtr(IntC, ArgType);
48 }
49 default:
50 std::string TypeString;
51 raw_string_ostream TypeOstream(TypeString);
52 ArgType->print(TypeOstream);
53 PROTEUS_FATAL_ERROR("JIT Incompatible type in runtime constant: " +
54 TypeOstream.str());
55 }
56}
57
59private:
60 static const RuntimeConstant *
61 findArgByOffset(const SmallVector<RuntimeConstant> &RCVec, int32_t Offset) {
62 for (auto &Arg : RCVec) {
63 if (Arg.Offset == Offset)
64 return &Arg;
65 }
66 return nullptr;
67 };
68
69 static const RuntimeConstant *
70 findArgByPos(const SmallVector<RuntimeConstant> &RCVec, int32_t Pos) {
71 for (auto &Arg : RCVec) {
72 if (Arg.Pos == Pos)
73 return &Arg;
74 }
75 return nullptr;
76 };
77
78 static auto traceOut(int Slot, Constant *C) {
79 SmallString<128> S;
80 raw_svector_ostream OS(S);
81 OS << "[LambdaSpec] Replacing slot " << Slot << " with " << *C << "\n";
82
83 return S;
84 };
85
86 static void handleLoad(Module &M, User *User,
87 const SmallVector<RuntimeConstant> &RCVec) {
88 auto *Arg = findArgByPos(RCVec, 0);
89 if (!Arg)
90 return;
91
92 Constant *C = getConstant(M.getContext(), User->getType(), *Arg);
93 User->replaceAllUsesWith(C);
94 PROTEUS_DBG(Logger::logs("proteus") << traceOut(Arg->Pos, C));
95 if (Config::get().ProteusTraceOutput >= 1)
96 Logger::trace(traceOut(Arg->Pos, C));
97 }
98
99 static void handleGEP(Module &M, GetElementPtrInst *GEP, User *User,
100 const SmallVector<RuntimeConstant> &RCVec) {
101 auto *GEPSlot = GEP->getOperand(User->getNumOperands() - 1);
102 ConstantInt *CI = dyn_cast<ConstantInt>(GEPSlot);
103 int Slot = CI->getZExtValue();
104 Type *SrcTy = GEP->getSourceElementType();
105
106 auto *Arg = SrcTy->isStructTy() ? findArgByPos(RCVec, Slot)
107 : findArgByOffset(RCVec, Slot);
108 if (!Arg)
109 return;
110
111 for (auto *GEPUser : GEP->users()) {
112 auto *LI = dyn_cast<LoadInst>(GEPUser);
113 if (!LI)
114 PROTEUS_FATAL_ERROR("Expected load instruction");
115 Type *LoadType = LI->getType();
116 Constant *C = getConstant(M.getContext(), LoadType, *Arg);
117 LI->replaceAllUsesWith(C);
118 PROTEUS_DBG(Logger::logs("proteus") << traceOut(Arg->Pos, C));
119 if (Config::get().ProteusTraceOutput >= 1)
120 Logger::trace(traceOut(Arg->Pos, C));
121 }
122 }
123
124public:
125 static void transform(Module &M, Function &F,
126 const SmallVector<RuntimeConstant> &RCVec) {
127 auto *LambdaClass = F.getArg(0);
128 PROTEUS_DBG(Logger::logs("proteus")
129 << "[LambdaSpec] Function: " << F.getName() << " RCVec size "
130 << RCVec.size() << "\n");
131 PROTEUS_DBG(Logger::logs("proteus")
132 << "TransformLambdaSpecialization::transform" << "\n");
133 PROTEUS_DBG(Logger::logs("proteus") << "\t args" << "\n");
134#if PROTEUS_ENABLE_DEBUG
135 for (auto &Arg : RCVec) {
136 Logger::logs("proteus")
137 << "{" << Arg.Value.Int64Val << ", " << Arg.Pos << " }\n";
138 }
139#endif
140
141 PROTEUS_DBG(Logger::logs("proteus") << "\t users" << "\n");
142 for (User *User : LambdaClass->users()) {
143 PROTEUS_DBG(Logger::logs("proteus") << *User << "\n");
144 if (isa<LoadInst>(User))
145 handleLoad(M, User, RCVec);
146 else if (auto *GEP = dyn_cast<GetElementPtrInst>(User))
147 handleGEP(M, GEP, User, RCVec);
148 }
149 }
150};
151
152} // namespace proteus
153
154#endif
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
static Config & get()
Definition Config.hpp:114
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition TransformLambdaSpecialization.hpp:58
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:125
Definition Helpers.h:76
Definition BuiltinsCUDA.cpp:4
static int Pos
Definition JitInterface.hpp:105
@ INT32
Definition CompilerInterfaceTypes.h:25
@ INT64
Definition CompilerInterfaceTypes.h:26
@ FLOAT
Definition CompilerInterfaceTypes.h:27
@ LONG_DOUBLE
Definition CompilerInterfaceTypes.h:29
@ INT8
Definition CompilerInterfaceTypes.h:24
@ BOOL
Definition CompilerInterfaceTypes.h:23
@ PTR
Definition CompilerInterfaceTypes.h:30
@ DOUBLE
Definition CompilerInterfaceTypes.h:28
Constant * getConstant(LLVMContext &Ctx, Type *ArgType, const RuntimeConstant &RC)
Definition TransformLambdaSpecialization.hpp:28
Definition CompilerInterfaceTypes.h:72
RuntimeConstantValue Value
Definition CompilerInterfaceTypes.h:73
RuntimeConstantType Type
Definition CompilerInterfaceTypes.h:74
double DoubleVal
Definition CompilerInterfaceTypes.h:65
int64_t Int64Val
Definition CompilerInterfaceTypes.h:63
bool BoolVal
Definition CompilerInterfaceTypes.h:60
int8_t Int8Val
Definition CompilerInterfaceTypes.h:61
int32_t Int32Val
Definition CompilerInterfaceTypes.h:62
float FloatVal
Definition CompilerInterfaceTypes.h:64
long double LongDoubleVal
Definition CompilerInterfaceTypes.h:66