Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
TransformSharedArray.hpp
Go to the documentation of this file.
1//===-- TransformSharedArray.hpp -- Shared array with specialized size--===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_TRANSFORM_SHARED_ARRAY_HPP
12#define PROTEUS_TRANSFORM_SHARED_ARRAY_HPP
13
14#include <llvm/Analysis/ConstantFolding.h>
15#include <llvm/Demangle/Demangle.h>
16#include <llvm/IR/IRBuilder.h>
17#include <llvm/IR/Verifier.h>
18#include <llvm/Support/Debug.h>
19
20#include "proteus/Debug.h"
21#include "proteus/Logger.hpp"
22#include "proteus/Utils.h"
23
24namespace proteus {
25
26using namespace llvm;
27
29public:
30 static void transform(Module &M) {
31 for (auto &Func : M.functions()) {
32 std::string DemangledName = llvm::demangle(Func.getName().str());
33 StringRef StrRef{DemangledName};
34 if (StrRef.contains("proteus::shared_array")) {
35 // Use a while loop to delete while iterating.
36 while (!Func.user_empty()) {
37 User *Usr = *Func.user_begin();
38 if (!isa<CallBase>(Usr))
39 PROTEUS_FATAL_ERROR("Expected call user");
40
41 CallBase *CB = cast<CallBase>(Usr);
42 assert(CB->arg_size() == 2 && "Expected 2 arguments: N and sizeof");
43 int64_t N;
44 int64_t Sizeof;
45 if (!getConstantValue(CB->getArgOperand(0), N, M.getDataLayout()))
46 PROTEUS_FATAL_ERROR("Expected constant N argument");
47 if (!getConstantValue(CB->getArgOperand(1), Sizeof,
48 M.getDataLayout()))
49 PROTEUS_FATAL_ERROR("Expected constant Sizeof argument");
50
51 ArrayType *AType =
52 ArrayType::get(Type::getInt8Ty(M.getContext()), N * Sizeof);
53 constexpr unsigned SharedMemAddrSpace = 3;
54 GlobalVariable *SharedMemGV = new GlobalVariable(
55 M, AType, false, GlobalValue::InternalLinkage,
56 UndefValue::get(AType), ".proteus.shared", nullptr,
57 llvm::GlobalValue::NotThreadLocal, SharedMemAddrSpace, false);
58 // Using 16-byte alignment based on AOT code generation.
59 // TODO: Create or find an API to query the proper ABI alignment.
60 SharedMemGV->setAlignment(Align{16});
61
62 auto TraceOut = [](StringRef DemangledName,
63 GlobalVariable *SharedMemGV) {
64 SmallString<128> S;
65 raw_svector_ostream OS(S);
66 OS << "[SharedArray] " << "Replace CB " << DemangledName << " with "
67 << *SharedMemGV << "\n";
68
69 return S;
70 };
71
72 PROTEUS_DBG(Logger::logs("proteus")
73 << TraceOut(DemangledName, SharedMemGV));
74 if (Config::get().ProteusTraceOutput)
75 Logger::trace(TraceOut(DemangledName, SharedMemGV));
76
77 CB->replaceAllUsesWith(ConstantExpr::getAddrSpaceCast(
78 SharedMemGV, CB->getFunctionType()->getReturnType()));
79 CB->eraseFromParent();
80 }
81
82#if PROTEUS_ENABLE_DEBUG
83 if (verifyModule(M, &errs()))
84 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
85#endif
86 }
87 }
88 }
89
90private:
91 static bool getConstantValue(Value *V, int64_t &Result,
92 const DataLayout &DL) {
93 // Directly access the value of a constant.
94 if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
95 Result = CI->getSExtValue();
96 return true;
97 }
98
99 // Fold an instruction.
100 if (Instruction *I = dyn_cast<Instruction>(V)) {
101 if (Value *FoldedV = ConstantFoldInstruction(I, DL, nullptr)) {
102 if (ConstantInt *CI = dyn_cast<ConstantInt>(FoldedV)) {
103 Result = CI->getSExtValue();
104 return true;
105 }
106 }
107 }
108
109 return false;
110 }
111};
112
113} // namespace proteus
114
115#endif
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
static Config & get()
Definition Config.hpp:112
Definition Func.hpp:19
StringRef getName()
Definition Func.hpp:117
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition TransformSharedArray.hpp:28
static void transform(Module &M)
Definition TransformSharedArray.hpp:30
Definition Dispatcher.cpp:14