Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitInterface.h
Go to the documentation of this file.
//===-- JitInterface.h -- user interface to Proteus JIT library --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11// NOLINTBEGIN(readability-identifier-naming)
12
13#ifndef PROTEUS_JIT_INTERFACE_H
14#define PROTEUS_JIT_INTERFACE_H
15
17#include "proteus/Init.h"
18
19#include <cassert>
20#include <cstring>
21#include <type_traits>
22#include <utility>
23
25 const char *AssociatedLambda);
26extern "C" void __jit_register_lambda(const char *Symbol);
27extern "C" void __jit_take_address(void const *) noexcept;
28
29namespace proteus {
30
/// Annotation marker for the argument value \p V.
///
/// Only declared in this header. It is kept out-of-line (noinline) so that
/// each call remains a distinct call in the emitted IR. Presumably the Proteus
/// compiler pass recognizes and consumes these calls (TODO confirm).
template <typename T>
__attribute__((noinline)) void
jit_arg(T V) noexcept;

#if defined(__CUDACC__) || defined(__HIP__)
/// Device-side counterpart of the host jit_arg() marker.
template <typename T>
__attribute__((noinline)) __device__ void
jit_arg(T V) noexcept;
#endif
36
/// Annotation marker for an array argument.
///
/// \param V        Pointer to the array data (T is expected to be a pointer).
/// \param NumElts  Presumably the number of elements behind \p V (TODO
///                 confirm); unused by any code visible here.
/// \param Velem    A value of the pointee type, defaulting to 0; its role is
///                 not visible from this header.
///
/// Only declared here and kept noinline so the call stays visible in the IR.
template <typename T>
__attribute__((noinline)) void
jit_array(T V, [[maybe_unused]] size_t NumElts,
          [[maybe_unused]] std::remove_pointer_t<T> Velem = 0) noexcept;

#if defined(__CUDACC__) || defined(__HIP__)
/// Device-side counterpart of the host jit_array() marker.
template <typename T>
__attribute__((noinline)) __device__ void
jit_array(T V, [[maybe_unused]] size_t NumElts,
          [[maybe_unused]] std::remove_pointer_t<T> Velem = 0) noexcept;
#endif
49
/// Annotation marker for the object pointed to by \p V.
///
/// Participates in overload resolution only when the pointee type is
/// trivially copyable. \p Size is a byte count that defaults to the size of
/// the pointee. Only declared here and kept noinline so the call stays
/// visible in the IR.
template <typename T>
__attribute__((noinline))
std::enable_if_t<std::is_trivially_copyable_v<std::remove_pointer_t<T>>>
jit_object(T *V, size_t Size = sizeof(std::remove_pointer_t<T>)) noexcept;

#if defined(__CUDACC__) || defined(__HIP__)
// NOTE(review): the host overload above defaults Size to
// sizeof(std::remove_pointer_t<T>), while this device overload uses sizeof(T).
// The two differ when T is itself a pointer (i.e. V is a pointer-to-pointer).
// Confirm which size is intended.
template <typename T>
__attribute__((noinline)) __device__
std::enable_if_t<std::is_trivially_copyable_v<std::remove_pointer_t<T>>>
jit_object(T *V, size_t Size = sizeof(T)) noexcept;
#endif
61
/// Annotation marker for the object \p V, which is passed by reference.
///
/// Participates in overload resolution only for non-pointer, trivially
/// copyable types; pointers go through the jit_object(T *) overload instead.
/// \p Size is a byte count that defaults to the size of the referenced
/// object. Only declared here and kept noinline so the call stays visible in
/// the IR.
template <typename T>
__attribute__((noinline)) std::enable_if_t<
    !std::is_pointer_v<T> &&
    std::is_trivially_copyable_v<std::remove_reference_t<T>>>
jit_object(T &V, size_t Size = sizeof(std::remove_reference_t<T>)) noexcept;

#if defined(__CUDACC__) || defined(__HIP__)
/// Device-side counterpart of the reference overload above. Because sizeof of
/// a reference type equals sizeof of the referenced type, this default equals
/// sizeof(T).
template <typename T>
__attribute__((noinline)) __device__ std::enable_if_t<
    !std::is_pointer_v<T> &&
    std::is_trivially_copyable_v<std::remove_reference_t<T>>>
jit_object(T &V, size_t Size = sizeof(std::remove_reference_t<T>)) noexcept;
#endif
77
78template <typename T> inline static RuntimeConstantType convertCTypeToRCType() {
79 if constexpr (std::is_same_v<T, bool>) {
81 } else if constexpr (std::is_integral_v<T> && sizeof(T) == sizeof(int8_t)) {
83 } else if constexpr (std::is_integral_v<T> && sizeof(T) == sizeof(int32_t)) {
85 } else if constexpr (std::is_integral_v<T> && sizeof(T) == sizeof(int64_t)) {
87 } else if constexpr (std::is_same_v<T, float>) {
89 } else if constexpr (std::is_same_v<T, double>) {
91 } else if constexpr (std::is_same_v<T, long double>) {
93 } else if constexpr (std::is_pointer_v<T>) {
95 } else {
97 }
98}
99
100template <typename T>
101static __attribute__((noinline)) T
102jit_variable(T V, int Pos = -1, int Offset = -1,
103 const char *AssociatedLambda = "") noexcept {
104 RuntimeConstant RC{convertCTypeToRCType<T>(), Pos, Offset};
105 std::memcpy(static_cast<void *>(&RC), &V, sizeof(T));
106 __jit_register_variable(RC, AssociatedLambda);
107
108 return V;
109}
110
/// Passes \p Symbol to __jit_register_lambda and records the closure type of
/// \p t, then returns \p t perfectly forwarded.
///
/// \param t       The lambda (or other callable) being registered.
/// \param Symbol  Forwarded to __jit_register_lambda; must be non-null
///                (checked by assert). Defaults to the empty string.
/// \return std::forward<T>(t).
///
/// NOTE(review): when \p t is a prvalue (T is a non-reference type), the
/// returned T&& refers to a temporary that dies at the end of the caller's
/// full-expression. Binding the result to a reference (e.g. `auto &&f =
/// register_lambda(...)`) dangles. Copying or moving it into a value is fine.
template <typename T>
static __attribute__((noinline)) T &&
register_lambda(T &&t, const char *Symbol = "") noexcept {
  assert(Symbol && "Expected non-null Symbol");
  __jit_register_lambda(Symbol);
  // Force LLVM to generate an AllocaInst of the underlying Clang--generated
  // anonymous class for T. We remove this after recording the demangled
  // lambda name.
  using LambdaType = std::decay_t<T>;
  // NOTE(review): this copy requires LambdaType to be copy-constructible, so a
  // move-only lambda (e.g. one capturing a std::unique_ptr) fails to compile
  // here. Also, because the function is noexcept, a throwing copy terminates
  // the program. Confirm both are acceptable.
  LambdaType local = t;
  // Taking the address keeps `local` in memory so the alloca is not
  // optimized away before the pass inspects it.
  __jit_take_address(&local);
  return std::forward<T>(t);
}
124
#if defined(__CUDACC__) || defined(__HIP__)
// The function needs to be static for RDC compilation to resolve the static
// shared memory fallback.
/// Returns a pointer to a statically sized `__shared__` buffer with room for
/// MAXN elements of T, aligned for T.
///
/// The buffer is a function-local static, so each distinct <T, MAXN, UniqueID>
/// instantiation owns a separate buffer. UniqueID is not otherwise used in
/// the body, which suggests it exists to let callers request separate
/// buffers of the same T/MAXN. \p N and \p ElemSize are unused by this body.
/// Presumably the JIT substitutes a size-\p N allocation and this body is
/// only the static fallback (TODO confirm).
///
/// NOTE(review): MAXN == 0 declares a zero-length array, which is only
/// accepted as a compiler extension.
template <typename T, size_t MAXN, int UniqueID = 0>
static __device__ __attribute__((noinline)) T *
shared_array([[maybe_unused]] size_t N,
             [[maybe_unused]] size_t ElemSize = sizeof(T)) {
  alignas(T) static __shared__ char shmem[sizeof(T) * MAXN];
  return reinterpret_cast<T *>(shmem);
}
#endif
136
137} // namespace proteus
138
139#endif
140
141// NOLINTEND(readability-identifier-naming)
void __jit_take_address(void const *) noexcept
void __jit_register_variable(proteus::RuntimeConstant RC, const char *AssociatedLambda)
void __jit_register_lambda(const char *Symbol)
Definition MemoryCache.h:26
size_t std::remove_pointer< T >::type Velem
Definition JitInterface.h:41
static int Pos
Definition JitInterface.h:102
RuntimeConstantType
Definition CompilerInterfaceTypes.h:20
@ NONE
Definition CompilerInterfaceTypes.h:22
@ INT32
Definition CompilerInterfaceTypes.h:25
@ INT64
Definition CompilerInterfaceTypes.h:26
@ FLOAT
Definition CompilerInterfaceTypes.h:27
@ LONG_DOUBLE
Definition CompilerInterfaceTypes.h:29
@ INT8
Definition CompilerInterfaceTypes.h:24
@ BOOL
Definition CompilerInterfaceTypes.h:23
@ PTR
Definition CompilerInterfaceTypes.h:30
@ DOUBLE
Definition CompilerInterfaceTypes.h:28
static int int Offset
Definition JitInterface.h:102
size_t NumElts
Definition JitInterface.h:39
__attribute__((noinline)) void jit_arg(T V) noexcept
Definition CompilerInterfaceTypes.h:72