Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
TransformArgumentSpecialization.hpp
Go to the documentation of this file.
1//===-- TransformArgumentSpecialization.hpp -- Specialize arguments --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_TRANSFORM_ARGUMENT_SPECIALIZATION_HPP
12#define PROTEUS_TRANSFORM_ARGUMENT_SPECIALIZATION_HPP
13
14#include <llvm/IR/IRBuilder.h>
15#include <llvm/Support/Debug.h>
16
18#include "proteus/Config.hpp"
19#include "proteus/Debug.h"
20#include "proteus/Logger.hpp"
22
23namespace proteus {
24
25using namespace llvm;
26
28private:
29 template <typename T>
30 static ArrayRef<T> createArrayRef(const RuntimeConstant &RC) {
31 T *TypedPtr = reinterpret_cast<T *>(RC.ArrInfo.Blob.get());
32 if (RC.ArrInfo.NumElts <= 0)
33 PROTEUS_FATAL_ERROR("Invalid number of elements in array: " +
34 std::to_string(RC.ArrInfo.NumElts));
35
36 return ArrayRef<T>(TypedPtr, RC.ArrInfo.NumElts);
37 }
38
// Builds a ConstantDataArray in M's LLVMContext from the array payload of RC.
//
// The element type is selected by RC.ArrInfo.EltType; each branch views the
// payload through createArrayRef<T>(RC) with the matching C++ element type.
// The return statements visible here cover bool, int8_t, int32_t, int64_t,
// float and double. Any other element type reaches the default branch, which
// reports a fatal error ("Unsupported array element type: ...").
//
// NOTE(review): the `case` labels that select each return, and the tail of the
// fatal-error expression, are not visible in this listing (gutter numbers 43,
// 45, 47, 50, 53, 55 and 59 are skipped) -- presumably dropped when the page
// was extracted; confirm against the original header.
39 static Constant *createConstantDataArray(Module &M,
40 const RuntimeConstant &RC) {
41 // Dispatch based on element type.
42 switch (RC.ArrInfo.EltType) {
44 return ConstantDataArray::get(M.getContext(), createArrayRef<bool>(RC));
46 return ConstantDataArray::get(M.getContext(), createArrayRef<int8_t>(RC));
48 return ConstantDataArray::get(M.getContext(),
49 createArrayRef<int32_t>(RC));
51 return ConstantDataArray::get(M.getContext(),
52 createArrayRef<int64_t>(RC));
54 return ConstantDataArray::get(M.getContext(), createArrayRef<float>(RC));
56 return ConstantDataArray::get(M.getContext(), createArrayRef<double>(RC));
57 default:
// Element types without a branch above are rejected rather than guessed at.
58 PROTEUS_FATAL_ERROR("Unsupported array element type: " +
60 }
61 }
62
// Builds a ConstantDataVector in M's LLVMContext from the array payload of RC.
//
// The element type is selected by RC.ArrInfo.EltType; each branch views the
// payload through createArrayRef<T>(RC). The return statements visible here
// use the unsigned integer types uint8_t, uint32_t and uint64_t, plus float
// and double. Any other element type reaches the default branch, which reports
// a fatal error ("Unsupported vector element type: ...").
//
// NOTE(review): the `case` labels that select each return, and the tail of the
// fatal-error expression, are not visible in this listing (gutter numbers 67,
// 70, 73, 76, 78 and 83 are skipped) -- presumably dropped when the page was
// extracted; confirm against the original header.
63 static Constant *createConstantDataVector(Module &M,
64 const RuntimeConstant &RC) {
65 // Dispatch based on element type.
66 switch (RC.ArrInfo.EltType) {
68 return ConstantDataVector::get(M.getContext(),
69 createArrayRef<uint8_t>(RC));
71 return ConstantDataVector::get(M.getContext(),
72 createArrayRef<uint32_t>(RC));
74 return ConstantDataVector::get(M.getContext(),
75 createArrayRef<uint64_t>(RC));
77 return ConstantDataVector::get(M.getContext(), createArrayRef<float>(RC));
79 return ConstantDataVector::get(M.getContext(),
80 createArrayRef<double>(RC));
81 default:
// Element types without a branch above are rejected rather than guessed at.
82 PROTEUS_FATAL_ERROR("Unsupported vector element type: " +
84 }
85 }
86
87public:
// Specializes function F by replacing uses of selected arguments with
// constants built from runtime values.
//
// Parameters:
//   M       - module that owns F; supplies the LLVMContext and receives any
//             global variables created for array/object payloads.
//   F       - function whose arguments are specialized.
//   RCArray - runtime constants; each entry names an argument by RC.Pos and
//             carries the value to substitute, tagged by RC.Type.
//
// For every entry the matching Argument is looked up, a Constant of the
// argument's type is built according to RC.Type, the replacement is optionally
// logged/traced, and all uses of the argument are replaced with the constant.
// An RC.Type with no matching branch ends in a fatal error.
//
// NOTE(review): the `case` labels of the switch on RC.Type are not visible in
// this listing (gutter numbers 100, 104, 108, 112, 116, 121, 125, 132, 137,
// 147, 151 and 155 are skipped) -- presumably dropped when the page was
// extracted; confirm against the original header.
88 static void transform(Module &M, Function &F,
89 ArrayRef<RuntimeConstant> RCArray) {
90 auto &Ctx = M.getContext();
91
92 // Replace argument uses with runtime constants.
93 for (const auto &RC : RCArray) {
// NOTE(review): RC.Pos is handed to F.getArg with no bounds check visible
// here -- confirm that callers guarantee a valid argument index.
94 int ArgNo = RC.Pos;
95 Argument *Arg = F.getArg(ArgNo);
96 Type *ArgType = Arg->getType();
97 Constant *C = nullptr;
98
// Each branch below sets C to a constant matching ArgType.
99 switch (RC.Type) {
// Integer-like scalars: the constant is created directly in the argument's
// own type from the matching member of RC.Value.
101 C = ConstantInt::get(ArgType, RC.Value.BoolVal);
102 break;
103 }
105 C = ConstantInt::get(ArgType, RC.Value.Int8Val);
106 break;
107 }
109 C = ConstantInt::get(ArgType, RC.Value.Int32Val);
110 break;
111 }
113 C = ConstantInt::get(ArgType, RC.Value.Int64Val);
114 break;
115 }
// Floating-point scalars: same scheme, using ConstantFP.
117 // Logger::logs("proteus") << "RC is Float\n";
118 C = ConstantFP::get(ArgType, RC.Value.FloatVal);
119 break;
120 }
122 C = ConstantFP::get(ArgType, RC.Value.DoubleVal);
123 break;
124 }
126 // NOTE: long double on device should correspond to plain double.
127 // XXX: CUDA with a long double SILENTLY fails to create a working
128 // kernel in AOT compilation, with or without JIT.
// NOTE(review): LongDoubleVal is presumably narrowed to double by this call --
// confirm which ConstantFP::get overload is selected.
129 C = ConstantFP::get(ArgType, RC.Value.LongDoubleVal);
130 break;
131 }
// Pointer value: the address is read from Int64Val, materialized as an i64
// constant, and converted to the argument's pointer type via inttoptr.
133 auto *IntC = ConstantInt::get(Type::getInt64Ty(Ctx), RC.Value.Int64Val);
134 C = ConstantExpr::getIntToPtr(IntC, ArgType);
135 break;
136 }
// Array passed through a pointer argument: the elements are embedded in the
// module as a private global (the `true` argument marks it constant) and the
// argument is replaced by that global's address.
138 Constant *CDA = createConstantDataArray(M, RC);
139 // Create a global variable to hold the array.
140 GlobalVariable *GV = new GlobalVariable(
141 M, CDA->getType(), true, GlobalValue::PrivateLinkage, CDA);
142
// NOTE(review): this branch uses getBitCast while the raw-object branch below
// uses getPointerBitCastOrAddrSpaceCast -- confirm the global and ArgType can
// never be in different address spaces here.
143 // Cast to the expected pointer type.
144 C = ConstantExpr::getBitCast(GV, ArgType);
145 break;
146 }
// Array/vector passed by value: the constant data itself replaces the
// argument, with no global variable involved.
148 C = createConstantDataArray(M, RC);
149 break;
150 }
152 C = createConstantDataVector(M, RC);
153 break;
154 }
// Opaque object: the RC.ObjInfo.Size bytes of RC.ObjInfo.Blob are copied into
// an i8 constant array, stored in a private constant global, and the argument
// is replaced by the global's address cast to ArgType.
156 Constant *CDA = ConstantDataArray::getRaw(
157 StringRef{reinterpret_cast<const char *>(RC.ObjInfo.Blob.get()),
158 static_cast<size_t>(RC.ObjInfo.Size)},
159 RC.ObjInfo.Size, Type::getInt8Ty(M.getContext()));
160 // Create a global variable to hold the array.
161 GlobalVariable *GV = new GlobalVariable(
162 M, CDA->getType(), true, GlobalValue::PrivateLinkage, CDA);
163
164 // Cast to the expected pointer type.
165 C = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, ArgType);
166 break;
167 }
// Unhandled RC.Type: print the argument's LLVM type into the error message.
168 default: {
169 std::string TypeString;
170 raw_string_ostream TypeOstream(TypeString);
171 ArgType->print(TypeOstream);
172 PROTEUS_FATAL_ERROR("JIT Incompatible type in runtime constant: " +
173 TypeOstream.str());
174 }
175 }
176
// Formats the message describing the replacement. stripPointerCasts() makes
// the printed value the underlying constant rather than the cast around it.
177 auto TraceOut = [](Function &F, int ArgNo, Constant *C) {
178 SmallString<128> S;
179 raw_svector_ostream OS(S);
180 OS << "[ArgSpec] Replaced Function " << F.getName() << " ArgNo "
181 << ArgNo << " with value " << *C->stripPointerCasts() << "\n";
182
183 return S;
184 };
185
// The message goes to the debug log through PROTEUS_DBG and, independently,
// to the trace sink when Config's ProteusTraceOutput flag is set.
186 PROTEUS_DBG(Logger::logs("proteus") << TraceOut(F, ArgNo, C));
187 if (Config::get().ProteusTraceOutput)
188 Logger::trace(TraceOut(F, ArgNo, C));
// Perform the actual specialization: every use of the argument now refers to
// the constant.
189 Arg->replaceAllUsesWith(C);
190 }
191 }
192};
193
194} // namespace proteus
195
196#endif
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
static Config & get()
Definition Config.hpp:114
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition TransformArgumentSpecialization.hpp:27
static void transform(Module &M, Function &F, ArrayRef< RuntimeConstant > RCArray)
Definition TransformArgumentSpecialization.hpp:88
Definition Helpers.h:76
Definition CppJitModule.cpp:21
@ ARRAY
Definition CompilerInterfaceTypes.h:33
@ VECTOR
Definition CompilerInterfaceTypes.h:32
@ INT32
Definition CompilerInterfaceTypes.h:25
@ INT64
Definition CompilerInterfaceTypes.h:26
@ FLOAT
Definition CompilerInterfaceTypes.h:27
@ STATIC_ARRAY
Definition CompilerInterfaceTypes.h:31
@ LONG_DOUBLE
Definition CompilerInterfaceTypes.h:29
@ INT8
Definition CompilerInterfaceTypes.h:24
@ BOOL
Definition CompilerInterfaceTypes.h:23
@ OBJECT
Definition CompilerInterfaceTypes.h:34
@ PTR
Definition CompilerInterfaceTypes.h:30
@ DOUBLE
Definition CompilerInterfaceTypes.h:28
std::string toString(CodegenOption Option)
Definition Config.hpp:23
int32_t NumElts
Definition CompilerInterfaceTypes.h:42
RuntimeConstantType EltType
Definition CompilerInterfaceTypes.h:43
std::shared_ptr< unsigned char[]> Blob
Definition CompilerInterfaceTypes.h:44
std::shared_ptr< unsigned char[]> Blob
Definition CompilerInterfaceTypes.h:53
int32_t Size
Definition CompilerInterfaceTypes.h:51
Definition CompilerInterfaceTypes.h:72
ArrayInfo ArrInfo
Definition CompilerInterfaceTypes.h:78
RuntimeConstantValue Value
Definition CompilerInterfaceTypes.h:73
RuntimeConstantType Type
Definition CompilerInterfaceTypes.h:74
ObjectInfo ObjInfo
Definition CompilerInterfaceTypes.h:79
int32_t Pos
Definition CompilerInterfaceTypes.h:75
double DoubleVal
Definition CompilerInterfaceTypes.h:65
int64_t Int64Val
Definition CompilerInterfaceTypes.h:63
bool BoolVal
Definition CompilerInterfaceTypes.h:60
int8_t Int8Val
Definition CompilerInterfaceTypes.h:61
int32_t Int32Val
Definition CompilerInterfaceTypes.h:62
float FloatVal
Definition CompilerInterfaceTypes.h:64
long double LongDoubleVal
Definition CompilerInterfaceTypes.h:66