// Halide 14.0.0 — Halide compiler and libraries
// gpu_context_common.h
// (Source listing recovered from the generated documentation page; see the
// rendered Halide docs for the original presentation of this file.)
1#include "printer.h"
2#include "scoped_mutex_lock.h"
3
4namespace Halide {
5namespace Internal {
6
7template<typename ContextT, typename ModuleStateT>
9 struct CachedCompilation {
10 ContextT context{};
11 ModuleStateT module_state{};
12 uint32_t kernel_id{};
13 uint32_t use_count{0};
14
15 CachedCompilation(ContextT context, ModuleStateT module_state,
16 uint32_t kernel_id, uint32_t use_count)
17 : context(context), module_state(module_state),
18 kernel_id(kernel_id), use_count(use_count) {
19 }
20 };
21
22 halide_mutex mutex;
23
24 static constexpr float kLoadFactor{.5f};
25 static constexpr int kInitialTableBits{7};
26 int log2_compilations_size{0}; // number of bits in index into compilations table.
27 CachedCompilation *compilations{nullptr};
28 int count{0};
29
30 static constexpr uint32_t kInvalidId{0};
31 static constexpr uint32_t kDeletedId{1};
32
33 uint32_t unique_id{2}; // zero is an invalid id
34
35public:
36 static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
37 uintptr_t addr = (uintptr_t)context + id;
38 // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
39 // in hexadecimal.
40 if (sizeof(uintptr_t) >= 8) {
41 return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
42 } else {
43 return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
44 }
45 }
46
47 HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
48 if (log2_compilations_size == 0) {
49 if (!resize_table(kInitialTableBits)) {
50 return false;
51 }
52 }
53 if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
54 if (!resize_table(log2_compilations_size + 1)) {
55 return false;
56 }
57 }
58 count += 1;
59 uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
60 for (int i = 0; i < (1 << log2_compilations_size); i++) {
61 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
62 if (compilations[effective_index].kernel_id <= kDeletedId) {
63 compilations[effective_index] = entry;
64 return true;
65 }
66 }
67 // This is a logic error that should never occur. It means the table is
68 // full, but it should have been resized.
69 halide_debug_assert(nullptr, false);
70 return false;
71 }
72
74 ModuleStateT *&module_state, int increment) {
75 if (log2_compilations_size == 0) {
76 return false;
77 }
78 uintptr_t index = kernel_hash(context, id, log2_compilations_size);
79 for (int i = 0; i < (1 << log2_compilations_size); i++) {
80 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
81
82 if (compilations[effective_index].kernel_id == kInvalidId) {
83 return false;
84 }
85 if (compilations[effective_index].context == context &&
86 compilations[effective_index].kernel_id == id) {
87 module_state = &compilations[effective_index].module_state;
88 if (increment != 0) {
89 compilations[effective_index].use_count += increment;
90 }
91 return true;
92 }
93 }
94 return false;
95 }
96
97 HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
98 ScopedMutexLock lock_guard(&mutex);
99 uint32_t id = (uint32_t)(uintptr_t)state_ptr;
100 ModuleStateT *mod_ptr;
101 if (find_internal(context, id, mod_ptr, 0)) {
102 module_state = *mod_ptr;
103 return true;
104 }
105 return false;
106 }
107
109 if (size_bits != log2_compilations_size) {
110 int new_size = (1 << size_bits);
111 int old_size = (1 << log2_compilations_size);
112 CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
113 if (new_table == nullptr) {
114 // signal error.
115 return false;
116 }
117 memset(new_table, 0, new_size * sizeof(CachedCompilation));
118 CachedCompilation *old_table = compilations;
119 compilations = new_table;
120 log2_compilations_size = size_bits;
121
122 if (count > 0) { // Mainly to catch empty initial table case
123 for (int32_t i = 0; i < old_size; i++) {
124 if (old_table[i].kernel_id != kInvalidId &&
125 old_table[i].kernel_id != kDeletedId) {
126 bool result = insert(old_table[i]);
127 halide_debug_assert(nullptr, result); // Resizing the table while resizing the table is a logic error.
128 (void)result;
129 }
130 }
131 }
132 free(old_table);
133 }
134 return true;
135 }
136
137 template<typename FreeModuleT>
138 void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
139 if (count == 0) {
140 return;
141 }
142
143 for (int i = 0; i < (1 << log2_compilations_size); i++) {
144 if (compilations[i].kernel_id > kInvalidId &&
145 (all || (compilations[i].context == context)) &&
146 compilations[i].use_count == 0) {
147 debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
148 << " id " << compilations[i].kernel_id
149 << " context " << compilations[i].context << "\n";
150 f(compilations[i].module_state);
151 compilations[i].module_state = nullptr;
152 compilations[i].kernel_id = kDeletedId;
153 count--;
154 }
155 }
156 }
157
158 template<typename FreeModuleT>
159 void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
160 ScopedMutexLock lock_guard(&mutex);
161
162 release_context(user_context, false, context, f);
163 }
164
165 template<typename FreeModuleT>
166 void release_all(void *user_context, FreeModuleT &f) {
167 ScopedMutexLock lock_guard(&mutex);
168
169 release_context(user_context, true, nullptr, f);
170 // Some items may have been in use, so can't free.
171 if (count == 0) {
172 free(compilations);
173 compilations = nullptr;
174 log2_compilations_size = 0;
175 }
176 }
177
178 template<typename CompileModuleT, typename... Args>
179 HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr,
180 ContextT context, ModuleStateT &result,
181 CompileModuleT f,
182 Args... args) {
183 ScopedMutexLock lock_guard(&mutex);
184
185 uint32_t *id_ptr = (uint32_t *)state_ptr;
186 if (*id_ptr == 0) {
187 *id_ptr = unique_id++;
188 }
189
190 ModuleStateT *mod;
191 if (find_internal(context, *id_ptr, mod, 1)) {
192 result = *mod;
193 return true;
194 }
195
196 // TODO(zvookin): figure out the calling signature here...
197 ModuleStateT compiled_module = f(args...);
198 debug(user_context) << "Caching compiled kernel: " << compiled_module
199 << " id " << *id_ptr << " context " << context << "\n";
200 if (compiled_module == nullptr) {
201 return false;
202 }
203
204 if (!insert({context, compiled_module, *id_ptr, 1})) {
205 return false;
206 }
207 result = compiled_module;
208
209 return true;
210 }
211
212 void release_hold(void *user_context, ContextT context, void *state_ptr) {
213 ModuleStateT *mod;
214 uint32_t id = (uint32_t)(uintptr_t)state_ptr;
215 bool result = find_internal(context, id, mod, -1);
216 halide_debug_assert(user_context, result); // Value must be in cache to be released
217 (void)result;
218 }
};

}  // namespace Internal
}  // namespace Halide
/* Cross-reference notes carried over from the documentation page this
   listing was extracted from (not part of the original header):

   #define HALIDE_MUST_USE_RESULT            — HalideRuntime.h:54
   bool lookup(ContextT, void *, ModuleStateT &)
   uintptr_t kernel_hash(ContextT, uint32_t, uint32_t)
   void release_hold(void *, ContextT, void *)
   bool insert(const CachedCompilation &)
   bool kernel_state_setup(void *, void **, ContextT, ModuleStateT &, CompileModuleT, Args...)
   void release_all(void *, FreeModuleT &)
   void delete_context(void *, ContextT, FreeModuleT &)
   bool resize_table(int)
   void release_context(void *, bool, ContextT, FreeModuleT &)
   bool find_internal(ContextT, uint32_t, ModuleStateT *&, int)
   debug — Debug.h:49: for optional debugging during codegen, use the debug
     class as described there
   mod — IRMatch.h:1066: HALIDE_ALWAYS_INLINE auto mod(A &&, B &&)
   Internal — not visible externally, similar to 'static' linkage in C
   halide_debug_assert(user_context, cond) — like halide_assert(), but only
     expands into a check when DEBUG_RUNTIME is defined
   void *malloc(size_t); void free(void *);
   void *memset(void *s, int val, size_t n)
   int32_t / uint32_t — fixed-width integer typedefs
   #define ALWAYS_INLINE
   halide_mutex — cross-platform mutex
*/