1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_HPP
// Fragment of a kernel-source generator: every statement appends a piece of OpenCL C text to `source`.
// NOTE(review): the function signature and the host-side if/else lines that pick between the
// alternatives below are not part of this excerpt; the alternatives simply appear in order.
43 template <
typename StringType>
// Loop header, first variant: row_gid/col_gid are derived from get_global_id(0) and get_local_size(0);
// the outer loop walks rows in steps of get_num_groups(0), the inner loop walks columns in steps of get_local_size(0).
48 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0);\n");
49 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0);\n");
50 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0))\n");
51 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0))\n");
// Loop header, second variant: the roles of rows and columns are swapped.
55 source.append(
" unsigned int col_gid = get_global_id(0) / get_local_size(0);\n");
56 source.append(
" unsigned int row_gid = get_global_id(0) % get_local_size(0);\n");
57 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_num_groups(0))\n");
58 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_local_size(0))\n");
// Element expression using start/inc offsets. For each operand there are two index forms:
// row scaled by *_internal_size2, or column scaled by *_internal_size1.
64 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + (col * A_inc2 + A_start2)] ");
66 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] ");
69 source.append(
" B[(row * B_inc1 + B_start1) * B_internal_size2 + (col * B_inc2 + B_start2)] ");
71 source.append(
" B[(row * B_inc1 + B_start1) + (col * B_inc2 + B_start2) * B_internal_size1] ");
// B is either multiplied or divided by alpha.
74 source.append(
"* alpha ");
76 source.append(
"/ alpha ");
// Additional C term, multiplied or divided by beta.
80 source.append(
"+ C[(row * C_inc1 + C_start1) * C_internal_size2 + (col * C_inc2 + C_start2)] ");
82 source.append(
"+ C[(row * C_inc1 + C_start1) + (col * C_inc2 + C_start2) * C_internal_size1] ");
84 source.append(
"* beta");
86 source.append(
"/ beta");
// Same element expression without start/inc offsets.
92 source.append(
" A[row * A_internal_size2 + col] ");
94 source.append(
" A[row + col * A_internal_size1] ");
97 source.append(
" B[row * B_internal_size2 + col] ");
99 source.append(
" B[row + col * B_internal_size1] ");
102 source.append(
"* alpha ");
104 source.append(
"/ alpha ");
108 source.append(
"+ C[row * C_internal_size2 + col] ");
// Fixed: the column index must be scaled by C_internal_size1, exactly like the A and B terms
// above and like the offset-aware C term; the previous text used C_internal_size2 here.
110 source.append(
"+ C[row + col * C_internal_size1] ");
112 source.append(
"* beta");
114 source.append(
"/ beta");
// Terminates the generated assignment statement.
117 source.append(
"; \n");
// Fragment of a generator that appends the OpenCL C source of a kernel whose name starts with `am`.
// NOTE(review): the function signature and the host-side conditionals that choose between the
// alternatives (`_cpu` vs `_gpu`, value vs pointer scalars, whether the C operand is emitted)
// are not visible in this excerpt; the element-wise statements inside the branch skeleton at the
// end are emitted by code that is not shown here either.
120 template <
typename StringType>
123 source.append(
"__kernel void am");
// Two pairs of name-suffix alternatives.
130 source.append(
"_cpu");
132 source.append(
"_gpu");
135 source.append(
"_cpu");
137 source.append(
"_gpu");
138 source.append(
"( \n");
// Output operand A with its start/inc/size/internal-size parameters.
139 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
140 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
141 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
142 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
143 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
// fac2 is declared either as a plain value or as a __global pointer.
146 source.append(
" "); source.append(numeric_string); source.append(
" fac2, \n");
150 source.append(
" __global "); source.append(numeric_string); source.append(
" * fac2, \n");
152 source.append(
" unsigned int options2, \n");
// Read-only operand B.
153 source.append(
" __global const "); source.append(numeric_string); source.append(
" * B, \n");
154 source.append(
" unsigned int B_start1, unsigned int B_start2, \n");
155 source.append(
" unsigned int B_inc1, unsigned int B_inc2, \n");
156 source.append(
" unsigned int B_internal_size1, unsigned int B_internal_size2");
// Continuation of the parameter list with fac3 / options3 / C.
160 source.append(
", \n\n");
163 source.append(
" "); source.append(numeric_string); source.append(
" fac3, \n");
167 source.append(
" __global "); source.append(numeric_string); source.append(
" * fac3, \n");
169 source.append(
" unsigned int options3, \n");
170 source.append(
" __global const "); source.append(numeric_string); source.append(
" * C, \n");
171 source.append(
" unsigned int C_start1, unsigned int C_start2, \n");
172 source.append(
" unsigned int C_inc1, unsigned int C_inc2, \n");
173 source.append(
" unsigned int C_internal_size1, unsigned int C_internal_size2 \n");
175 source.append(
") { \n");
// alpha is initialised from fac2 (value form) or fac2[0] (pointer form).
179 source.append(
" "); source.append(numeric_string); source.append(
" alpha = fac2; \n");
183 source.append(
" "); source.append(numeric_string); source.append(
" alpha = fac2[0]; \n");
// Bit 0 of options2 negates alpha.
185 source.append(
" if (options2 & (1 << 0)) \n");
186 source.append(
" alpha = -alpha; \n");
187 source.append(
" \n");
// beta is initialised from fac3 (value form) or fac3[0] (pointer form).
191 source.append(
" "); source.append(numeric_string); source.append(
" beta = fac3; \n");
195 source.append(
" "); source.append(numeric_string); source.append(
" beta = fac3[0]; \n");
// Bit 0 of options3 negates beta.
199 source.append(
" if (options3 & (1 << 0)) \n");
200 source.append(
" beta = -beta; \n");
201 source.append(
" \n");
// Branch skeleton on bit 1 of options2 and bit 1 of options3.
// NOTE(review): presumably these bits select division vs multiplication by alpha/beta — confirm against the code that fills the branches.
203 source.append(
" if (options2 & (1 << 1)) { \n");
206 source.append(
" if (options3 & (1 << 1)) {\n");
208 source.append(
" } else {\n");
210 source.append(
" } \n");
214 source.append(
" } else { \n");
217 source.append(
" if (options3 & (1 << 1)) {\n");
219 source.append(
" } else {\n");
221 source.append(
" } \n");
225 source.append(
" } \n");
226 source.append(
"} \n");
229 template <
typename StringType>
// Fragment of a generator that appends the OpenCL C source of kernel `assign_cpu`,
// which stores the scalar `alpha` into every addressed entry of A.
// NOTE(review): the function signature and the host-side condition selecting one of the two
// loop variants below are not visible in this excerpt.
256 template <
typename StringType>
259 source.append(
"__kernel void assign_cpu( \n");
260 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
261 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
262 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
263 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
264 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
265 source.append(
" "); source.append(numeric_string); source.append(
" alpha) \n");
266 source.append(
"{ \n");
// Variant 1: outer loop over rows (step get_num_groups(0)), inner loop over columns
// (step get_local_size(0)); the row index is scaled by A_internal_size2.
269 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0);\n");
270 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0);\n");
271 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0))\n");
272 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0))\n");
273 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + (col * A_inc2 + A_start2)] = alpha; \n");
// Variant 2: outer loop over columns, inner loop over rows; the column index is scaled by A_internal_size1.
277 source.append(
" unsigned int row_gid = get_global_id(0) % get_local_size(0);\n");
278 source.append(
" unsigned int col_gid = get_global_id(0) / get_local_size(0);\n");
279 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_num_groups(0))\n");
280 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_local_size(0))\n");
281 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] = alpha; \n");
283 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel `diagonal_assign_cpu`,
// which stores `alpha` into entries (idx, idx) of A for idx < min(A_size1, A_size2).
// NOTE(review): the function signature and the host-side condition selecting one of the two
// store statements are not visible in this excerpt.
286 template <
typename StringType>
289 source.append(
"__kernel void diagonal_assign_cpu( \n");
290 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
291 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
292 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
293 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
294 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
295 source.append(
" "); source.append(numeric_string); source.append(
" alpha) \n");
296 source.append(
"{ \n");
// Each work-item starts at get_global_id(0) and advances by get_global_size(0).
297 source.append(
" for (unsigned int idx = get_global_id(0); idx < min(A_size1, A_size2); idx += get_global_size(0))\n");
// Store alternative 1: first index scaled by A_internal_size2.
299 source.append(
" A[(idx * A_inc1 + A_start1) * A_internal_size2 + (idx * A_inc2 + A_start2)] = alpha; \n");
// Store alternative 2: second index scaled by A_internal_size1.
301 source.append(
" A[(idx * A_inc1 + A_start1) + (idx * A_inc2 + A_start2) * A_internal_size1] = alpha; \n");
302 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel `element_op`.
// The generated kernel combines B and C entry by entry into A depending on `op_type`:
// 2 -> pow(B, C), 1 -> B / C, 0 -> B * C.
// NOTE(review): the function signature, the host-side condition selecting between the two
// index-layout variants, and the braces around the `numeric_string` checks are not visible here.
305 template <
typename StringType>
308 source.append(
"__kernel void element_op( \n");
309 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
310 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
311 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
312 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
313 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
314 source.append(
" __global "); source.append(numeric_string); source.append(
" * B, \n");
315 source.append(
" unsigned int B_start1, unsigned int B_start2, \n");
316 source.append(
" unsigned int B_inc1, unsigned int B_inc2, \n");
317 source.append(
" unsigned int B_internal_size1, unsigned int B_internal_size2, \n");
318 source.append(
" __global "); source.append(numeric_string); source.append(
" * C, \n");
319 source.append(
" unsigned int C_start1, unsigned int C_start2, \n");
320 source.append(
" unsigned int C_inc1, unsigned int C_inc2, \n");
321 source.append(
" unsigned int C_internal_size1, unsigned int C_internal_size2, \n");
322 source.append(
" unsigned int op_type) \n");
323 source.append(
"{ \n");
// Variant 1: row index scaled by *_internal_size2; outer loop over rows, inner loop over columns.
326 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0);\n");
327 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0);\n");
328 source.append(
" if (op_type == 2) {");
// Host-side check on the scalar type name preceding the pow() text.
// NOTE(review): presumably the pow() body is only emitted for "float"/"double" — the guarded scope is not visible here.
329 if (numeric_string ==
"float" || numeric_string ==
"double")
331 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0))\n");
332 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0))\n");
333 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + (col * A_inc2 + A_start2)] = \n");
334 source.append(
" pow(B[(row * B_inc1 + B_start1) * B_internal_size2 + (col * B_inc2 + B_start2)], \n");
335 source.append(
" C[(row * C_inc1 + C_start1) * C_internal_size2 + (col * C_inc2 + C_start2)]); \n");
// op_type == 1: entry-wise division.
337 source.append(
" } else if (op_type == 1) {");
338 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0))\n");
339 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0))\n");
340 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + (col * A_inc2 + A_start2)] = \n");
341 source.append(
" B[(row * B_inc1 + B_start1) * B_internal_size2 + (col * B_inc2 + B_start2)] / \n");
342 source.append(
" C[(row * C_inc1 + C_start1) * C_internal_size2 + (col * C_inc2 + C_start2)]; \n");
// op_type == 0: entry-wise multiplication.
343 source.append(
" } else if (op_type == 0) {");
344 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0))\n");
345 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0))\n");
346 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + (col * A_inc2 + A_start2)] = \n");
347 source.append(
" B[(row * B_inc1 + B_start1) * B_internal_size2 + (col * B_inc2 + B_start2)] * \n");
348 source.append(
" C[(row * C_inc1 + C_start1) * C_internal_size2 + (col * C_inc2 + C_start2)]; \n");
// Variant 2: column index scaled by *_internal_size1; outer loop over columns, inner loop over rows.
353 source.append(
" unsigned int row_gid = get_global_id(0) % get_local_size(0);\n");
354 source.append(
" unsigned int col_gid = get_global_id(0) / get_local_size(0);\n");
355 source.append(
" if (op_type == 2) {");
356 if (numeric_string ==
"float" || numeric_string ==
"double")
358 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_num_groups(0))\n");
359 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_local_size(0))\n");
360 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] = \n");
361 source.append(
" pow(B[(row * B_inc1 + B_start1) + (col * B_inc2 + B_start2) * B_internal_size1], \n");
362 source.append(
" C[(row * C_inc1 + C_start1) + (col * C_inc2 + C_start2) * C_internal_size1]); \n");
364 source.append(
" } else if (op_type == 1) {");
365 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_num_groups(0))\n");
366 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_local_size(0))\n");
367 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] = \n");
368 source.append(
" B[(row * B_inc1 + B_start1) + (col * B_inc2 + B_start2) * B_internal_size1] / \n");
369 source.append(
" C[(row * C_inc1 + C_start1) + (col * C_inc2 + C_start2) * C_internal_size1]; \n");
370 source.append(
" } else if (op_type == 0) {");
371 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_num_groups(0))\n");
372 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_local_size(0))\n");
373 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] = \n");
374 source.append(
" B[(row * B_inc1 + B_start1) + (col * B_inc2 + B_start2) * B_internal_size1] * \n");
375 source.append(
" C[(row * C_inc1 + C_start1) + (col * C_inc2 + C_start2) * C_internal_size1]; \n");
378 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel `fft_direct`.
// For every batch and every output index k the generated kernel sums, over n = 0..size-1,
// the product of the complex input sample with (cos(arg), sin(arg)), where
// arg = sign * 2 * NUM_PI * k / size * n, and writes the sum to `output`.
// NOTE(review): the function signature and the host-side condition selecting between the two
// addressing alternatives (batch_id * stride + i vs i * stride + batch_id) are not visible here.
382 template <
typename StringType>
386 source.append(
"__kernel void fft_direct(__global "); source.append(numeric_string); source.append(
"2 *input, \n");
387 source.append(
" __global "); source.append(numeric_string); source.append(
"2 *output, \n");
388 source.append(
" unsigned int size, \n");
389 source.append(
" unsigned int stride, \n");
390 source.append(
" unsigned int batch_num, \n");
391 source.append(
" "); source.append(numeric_string); source.append(
" sign) { \n");
392 source.append(
" const "); source.append(numeric_string); source.append(
" NUM_PI = 3.14159265358979323846; \n");
393 source.append(
" \n");
394 source.append(
" for(unsigned int batch_id = 0; batch_id < batch_num; batch_id++) { \n");
// Output indices are distributed over work-items: start at get_global_id(0), step get_global_size(0).
395 source.append(
" for(unsigned int k = get_global_id(0); k < size; k += get_global_size(0)) { \n");
396 source.append(
" "); source.append(numeric_string); source.append(
"2 f = 0.0f; \n");
397 source.append(
" \n");
398 source.append(
" for(unsigned int n = 0; n < size; n++) { \n");
399 source.append(
" "); source.append(numeric_string); source.append(
"2 in = ");
// Two alternatives for addressing the input sample.
401 source.append(
"input[batch_id * stride + n]; \n");
403 source.append(
"input[n * stride + batch_id]; \n");
404 source.append(
" \n");
405 source.append(
" "); source.append(numeric_string); source.append(
" sn, cs; \n");
406 source.append(
" "); source.append(numeric_string); source.append(
" arg = sign * 2 * NUM_PI * k / size * n; \n");
// sincos returns the sine and stores the cosine through its pointer argument.
407 source.append(
" sn = sincos(arg, &cs); \n");
408 source.append(
" \n");
409 source.append(
" "); source.append(numeric_string); source.append(
"2 ex = ("); source.append(numeric_string); source.append(
"2)(cs, sn); \n");
// Complex multiplication in * ex, accumulated into f.
410 source.append(
" f = f + ("); source.append(numeric_string); source.append(
"2)(in.x * ex.x - in.y * ex.y, in.x * ex.y + in.y * ex.x); \n");
411 source.append(
" } \n");
412 source.append(
" \n");
// Two alternatives for addressing the output sample.
414 source.append(
" output[batch_id * stride + k] = f; \n");
416 source.append(
" output[k * stride + batch_id] = f; \n");
417 source.append(
" } \n");
418 source.append(
" } \n");
419 source.append(
"} \n");
421 source.append(
" \n");
// Appends the OpenCL C source of kernel `fft_radix2`: one in-place butterfly pass over `input`
// for stage `s` (ss = 1 << s). Each work-item handles pairs (pos, pos + ss) for tid < size / 2.
// NOTE(review): the host-side condition selecting between the two addressing alternatives
// (batch_id * stride + pos vs pos * stride + batch_id) is not visible in this excerpt.
423 source.append(
"__kernel void fft_radix2(__global "); source.append(numeric_string); source.append(
"2* input, \n");
424 source.append(
" unsigned int s, \n");
425 source.append(
" unsigned int bit_size, \n");
426 source.append(
" unsigned int size, \n");
427 source.append(
" unsigned int stride, \n");
428 source.append(
" unsigned int batch_num, \n");
429 source.append(
" "); source.append(numeric_string); source.append(
" sign) { \n");
430 source.append(
" \n");
431 source.append(
" unsigned int ss = 1 << s; \n");
432 source.append(
" unsigned int half_size = size >> 1; \n");
433 source.append(
" \n");
434 source.append(
" "); source.append(numeric_string); source.append(
" cs, sn; \n");
435 source.append(
" const "); source.append(numeric_string); source.append(
" NUM_PI = 3.14159265358979323846; \n");
436 source.append(
" \n");
437 source.append(
" unsigned int glb_id = get_global_id(0); \n");
438 source.append(
" unsigned int glb_sz = get_global_size(0); \n");
440 source.append(
" for(unsigned int batch_id = 0; batch_id < batch_num; batch_id++) { \n");
441 source.append(
" for(unsigned int tid = glb_id; tid < half_size; tid += glb_sz) { \n");
// group = position inside the butterfly group; pos = index of the first element of the pair.
442 source.append(
" unsigned int group = (tid & (ss - 1)); \n");
443 source.append(
" unsigned int pos = ((tid >> s) << (s + 1)) + group; \n");
// Load alternative 1: pair partner is ss elements further.
447 source.append(
" unsigned int offset = batch_id * stride + pos; \n");
448 source.append(
" "); source.append(numeric_string); source.append(
"2 in1 = input[offset]; \n");
449 source.append(
" "); source.append(numeric_string); source.append(
"2 in2 = input[offset + ss]; \n");
// Load alternative 2: pair partner is ss * stride elements further.
453 source.append(
" unsigned int offset = pos * stride + batch_id; \n");
454 source.append(
" "); source.append(numeric_string); source.append(
"2 in1 = input[offset]; \n");
455 source.append(
" "); source.append(numeric_string); source.append(
"2 in2 = input[offset + ss * stride]; \n");
// Twiddle factor ex = (cos(arg), sin(arg)) with arg = group * sign * NUM_PI / ss.
458 source.append(
" "); source.append(numeric_string); source.append(
" arg = group * sign * NUM_PI / ss; \n");
460 source.append(
" sn = sincos(arg, &cs); \n");
462 source.append(
" "); source.append(numeric_string); source.append(
"2 ex = ("); source.append(numeric_string); source.append(
"2)(cs, sn); \n");
// tmp = in2 * ex (complex multiplication).
464 source.append(
" "); source.append(numeric_string); source.append(
"2 tmp = ("); source.append(numeric_string); source.append(
"2)(in2.x * ex.x - in2.y * ex.y, in2.x * ex.y + in2.y * ex.x); \n");
// Store alternatives matching the two load alternatives above.
467 source.append(
" input[offset + ss] = in1 - tmp; \n");
469 source.append(
" input[offset + ss * stride] = in1 - tmp; \n");
470 source.append(
" input[offset] = in1 + tmp; \n");
471 source.append(
" } \n");
472 source.append(
" } \n");
473 source.append(
"} \n");
475 source.append(
" \n");
// Appends the OpenCL C source of helper `get_reorder_num` and kernel `fft_radix2_local`.
// get_reorder_num reverses the 32 bits of v (swapping bits, bit pairs, nibbles, bytes, halves)
// and then shifts right by (32 - bit_size).
// fft_radix2_local copies one batch into __local memory in reordered positions, runs bit_size
// butterfly stages there with barriers between stages, and copies the result back to `input`.
// NOTE(review): the host-side condition selecting between the two global addressing alternatives
// is not visible in this excerpt.
477 source.append(
" unsigned int get_reorder_num(unsigned int v, unsigned int bit_size) { \n");
478 source.append(
" v = ((v >> 1) & 0x55555555) | ((v & 0x55555555) << 1); \n");
479 source.append(
" v = ((v >> 2) & 0x33333333) | ((v & 0x33333333) << 2); \n");
480 source.append(
" v = ((v >> 4) & 0x0F0F0F0F) | ((v & 0x0F0F0F0F) << 4); \n");
481 source.append(
" v = ((v >> 8) & 0x00FF00FF) | ((v & 0x00FF00FF) << 8); \n");
482 source.append(
" v = (v >> 16) | (v << 16); \n");
483 source.append(
" \n");
484 source.append(
" v = v >> (32 - bit_size); \n");
485 source.append(
" \n");
486 source.append(
" return v; \n");
487 source.append(
" } \n");
489 source.append(
" __kernel void fft_radix2_local(__global "); source.append(numeric_string); source.append(
"2* input, \n");
490 source.append(
" __local "); source.append(numeric_string); source.append(
"2* lcl_input, \n");
491 source.append(
" unsigned int bit_size, \n");
492 source.append(
" unsigned int size, \n");
493 source.append(
" unsigned int stride, \n");
494 source.append(
" unsigned int batch_num, \n");
495 source.append(
" "); source.append(numeric_string); source.append(
" sign) { \n");
497 source.append(
" unsigned int grp_id = get_group_id(0); \n");
498 source.append(
" unsigned int grp_num = get_num_groups(0); \n");
500 source.append(
" unsigned int lcl_sz = get_local_size(0); \n");
501 source.append(
" unsigned int lcl_id = get_local_id(0); \n");
502 source.append(
" const "); source.append(numeric_string); source.append(
" NUM_PI = 3.14159265358979323846; \n");
// Batches are distributed over work-groups.
504 source.append(
" for(unsigned int batch_id = grp_id; batch_id < batch_num; batch_id += grp_num) { \n");
// Copy the batch into local memory at reordered positions.
507 source.append(
" for(unsigned int p = lcl_id; p < size; p += lcl_sz) { \n");
508 source.append(
" unsigned int v = get_reorder_num(p, bit_size); \n");
510 source.append(
" lcl_input[v] = input[batch_id * stride + p]; \n");
512 source.append(
" lcl_input[v] = input[p * stride + batch_id]; \n");
513 source.append(
" } \n");
515 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
// Butterfly stages s = 0 .. bit_size - 1 on the local copy.
518 source.append(
" for(unsigned int s = 0; s < bit_size; s++) { \n");
519 source.append(
" unsigned int ss = 1 << s; \n");
521 source.append(
" "); source.append(numeric_string); source.append(
" cs, sn; \n");
// NOTE(review): tid runs up to `size` here (fft_radix2 uses size / 2), so pos + ss can reach
// indices >= size; confirm that the host sizes lcl_input accordingly.
523 source.append(
" for(unsigned int tid = lcl_id; tid < size; tid += lcl_sz) { \n");
524 source.append(
" unsigned int group = (tid & (ss - 1)); \n");
525 source.append(
" unsigned int pos = ((tid >> s) << (s + 1)) + group; \n");
527 source.append(
" "); source.append(numeric_string); source.append(
"2 in1 = lcl_input[pos]; \n");
528 source.append(
" "); source.append(numeric_string); source.append(
"2 in2 = lcl_input[pos + ss]; \n");
530 source.append(
" "); source.append(numeric_string); source.append(
" arg = group * sign * NUM_PI / ss; \n");
532 source.append(
" sn = sincos(arg, &cs); \n");
533 source.append(
" "); source.append(numeric_string); source.append(
"2 ex = ("); source.append(numeric_string); source.append(
"2)(cs, sn); \n");
535 source.append(
" "); source.append(numeric_string); source.append(
"2 tmp = ("); source.append(numeric_string); source.append(
"2)(in2.x * ex.x - in2.y * ex.y, in2.x * ex.y + in2.y * ex.x); \n");
537 source.append(
" lcl_input[pos + ss] = in1 - tmp; \n");
538 source.append(
" lcl_input[pos] = in1 + tmp; \n");
539 source.append(
" } \n");
// Work-group barrier between stages.
541 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
542 source.append(
" } \n");
// Copy the local result back to global memory.
545 source.append(
" for(unsigned int p = lcl_id; p < size; p += lcl_sz) { \n");
547 source.append(
" input[batch_id * stride + p] = lcl_input[p]; \n");
549 source.append(
" input[p * stride + batch_id] = lcl_input[p]; \n");
550 source.append(
" } \n");
551 source.append(
" } \n");
552 source.append(
" } \n");
554 source.append(
" \n");
// Appends the OpenCL C source of helper `get_reorder_num_2` and kernel `fft_reorder`.
// get_reorder_num_2 reverses the 32 bits of v and shifts right by (32 - bit_size).
// fft_reorder swaps input[i] with input[v] (v = get_reorder_num_2(i, bit_size)) whenever i < v,
// so each pair is exchanged exactly once.
// NOTE(review): the host-side condition selecting between the two addressing alternatives
// (batch_id * stride + i vs i * stride + batch_id) is not visible in this excerpt.
560 source.append(
"unsigned int get_reorder_num_2(unsigned int v, unsigned int bit_size) { \n");
561 source.append(
" v = ((v >> 1) & 0x55555555) | ((v & 0x55555555) << 1); \n");
562 source.append(
" v = ((v >> 2) & 0x33333333) | ((v & 0x33333333) << 2); \n");
563 source.append(
" v = ((v >> 4) & 0x0F0F0F0F) | ((v & 0x0F0F0F0F) << 4); \n");
564 source.append(
" v = ((v >> 8) & 0x00FF00FF) | ((v & 0x00FF00FF) << 8); \n");
565 source.append(
" v = (v >> 16) | (v << 16); \n");
567 source.append(
" v = v >> (32 - bit_size); \n");
569 source.append(
" return v; \n");
570 source.append(
"} \n");
572 source.append(
"__kernel void fft_reorder(__global "); source.append(numeric_string); source.append(
"2* input, \n");
573 source.append(
" unsigned int bit_size, \n");
574 source.append(
" unsigned int size, \n");
575 source.append(
" unsigned int stride, \n");
// batch_num is declared unsigned so that it matches the unsigned loop counter it is compared
// against and the batch_num parameter of the other FFT kernels (same 4-byte argument size).
576 source.append(
" unsigned int batch_num) { \n");
578 source.append(
" unsigned int glb_id = get_global_id(0); \n");
579 source.append(
" unsigned int glb_sz = get_global_size(0); \n");
581 source.append(
" for(unsigned int batch_id = 0; batch_id < batch_num; batch_id++) { \n");
582 source.append(
" for(unsigned int i = glb_id; i < size; i += glb_sz) { \n");
583 source.append(
" unsigned int v = get_reorder_num_2(i, bit_size); \n");
// Only the lower index of each pair performs the swap.
585 source.append(
" if(i < v) {\n");
// Swap alternative 1.
588 source.append(
" "); source.append(numeric_string); source.append(
"2 tmp = input[batch_id * stride + i]; \n");
589 source.append(
" input[batch_id * stride + i] = input[batch_id * stride + v]; \n");
590 source.append(
" input[batch_id * stride + v] = tmp; \n");
// Swap alternative 2.
594 source.append(
" "); source.append(numeric_string); source.append(
"2 tmp = input[i * stride + batch_id]; \n");
595 source.append(
" input[i * stride + batch_id] = input[v * stride + batch_id]; \n");
596 source.append(
" input[v * stride + batch_id] = tmp; \n");
598 source.append(
" } \n");
599 source.append(
" } \n");
600 source.append(
" } \n");
601 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel `lu_factorize`,
// an in-place elimination on `matrix`: for every row i and every k < i the entry (i, k) is
// divided by the diagonal entry (k, k) by work-item 0, and the remainder of row i is then
// updated in parallel with the scaled row k.
// NOTE(review): the function signature, the host-side condition selecting between the two
// storage variants, and the closing braces of the generated loops are not visible in this excerpt.
604 template <
typename StringType>
607 source.append(
"__kernel void lu_factorize( \n");
608 source.append(
" __global "); source.append(numeric_string); source.append(
" * matrix, \n");
609 source.append(
" unsigned int matrix_rows, \n");
610 source.append(
" unsigned int matrix_cols, \n");
611 source.append(
" unsigned int matrix_internal_rows, \n");
612 source.append(
" unsigned int matrix_internal_cols) \n");
613 source.append(
"{ \n");
614 source.append(
" "); source.append(numeric_string); source.append(
" temp; \n");
// Variant 1: entry (r, c) lives at r * matrix_internal_cols + c.
618 source.append(
" unsigned rowi; \n");
619 source.append(
" unsigned rowk; \n");
620 source.append(
" for (unsigned int i=1; i<matrix_rows; ++i) \n");
621 source.append(
" { \n");
622 source.append(
" rowi = i * matrix_internal_cols; \n");
623 source.append(
" for (unsigned int k=0; k<i; ++k) \n");
624 source.append(
" { \n");
625 source.append(
" rowk = k * matrix_internal_cols; \n");
626 source.append(
" if (get_global_id(0) == 0) \n");
627 source.append(
" matrix[rowi + k] /= matrix[rowk + k]; \n");
629 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
630 source.append(
" temp = matrix[rowi + k]; \n");
// Fixed: j is a column index, so it is bounded by matrix_cols (as in the second variant below);
// the previous bound matrix_rows was only correct for square matrices.
633 source.append(
" for (unsigned int j=k+1 + get_global_id(0); j<matrix_cols; j += get_global_size(0)) \n");
634 source.append(
" matrix[rowi + j] -= temp * matrix[rowk + j]; \n");
// Variant 2: entry (r, c) lives at r + c * matrix_internal_rows.
638 source.append(
" for (unsigned int i=1; i<matrix_rows; ++i) \n");
639 source.append(
" { \n");
640 source.append(
" for (unsigned int k=0; k<i; ++k) \n");
641 source.append(
" { \n");
643 source.append(
" if (get_global_id(0) == 0) \n");
644 source.append(
" matrix[i + k*matrix_internal_rows] /= matrix[k + k*matrix_internal_rows]; \n");
646 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
647 source.append(
" temp = matrix[i + k*matrix_internal_rows]; \n");
650 source.append(
" for (unsigned int j=k+1 + get_global_id(0); j<matrix_cols; j += get_global_size(0)) \n");
651 source.append(
" matrix[i + j*matrix_internal_rows] -= temp * matrix[k + j*matrix_internal_rows]; \n");
// Fragment of a generator that appends the OpenCL C source of kernel
// `scaled_rank1_update_cpu` / `scaled_rank1_update_gpu` (suffix chosen by `alpha_on_cpu`).
// The generated kernel adds (vec1[row] scaled by alpha) * vec2[col] to every addressed entry of A;
// bit 0 of options2 negates alpha, bit 1 switches the scaling from multiplication to division.
// NOTE(review): the function signature and the host-side conditions selecting between the
// value/pointer forms of `val` and between the two A index forms are not visible in this excerpt.
659 template <
typename StringType>
662 source.append(
"__kernel void scaled_rank1_update_"); alpha_on_cpu ? source.append(
"cpu") : source.append(
"gpu"); source.append(
"( \n");
663 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
664 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
665 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
666 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
667 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
// `val` is declared either as a plain value or as a __global const pointer.
670 source.append(
" "); source.append(numeric_string); source.append(
" val, \n");
672 source.append(
" __global const "); source.append(numeric_string); source.append(
" *val, \n");
674 source.append(
" unsigned int options2, \n");
676 source.append(
" __global const "); source.append(numeric_string); source.append(
" * vec1, \n");
677 source.append(
" unsigned int start1, \n");
678 source.append(
" unsigned int inc1, \n");
679 source.append(
" unsigned int size1, \n");
681 source.append(
" __global const "); source.append(numeric_string); source.append(
" * vec2, \n");
682 source.append(
" unsigned int start2, \n");
683 source.append(
" unsigned int inc2, \n");
684 source.append(
" unsigned int size2) \n");
685 source.append(
"{ \n");
// alpha is initialised from val (value form) or val[0] (pointer form).
688 source.append(
" "); source.append(numeric_string); source.append(
" alpha = val; \n");
690 source.append(
" "); source.append(numeric_string); source.append(
" alpha = val[0]; \n");
692 source.append(
" if (options2 & (1 << 0)) \n");
693 source.append(
" alpha = -alpha; \n");
695 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0); \n");
696 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0); \n");
698 source.append(
" for (unsigned int row = row_gid; row < A_size1; row += get_num_groups(0)) \n");
699 source.append(
" { \n");
// The scaled vec1 entry is computed once per row and reused for all columns.
700 source.append(
" "); source.append(numeric_string); source.append(
" tmp = vec1[row * inc1 + start1];");
701 source.append(
" tmp = (options2 & (1 << 1)) ? tmp / alpha : tmp * alpha;");
702 source.append(
" for (unsigned int col = col_gid; col < A_size2; col += get_local_size(0)) \n");
// Update alternative 1: row index scaled by A_internal_size2.
704 source.append(
" A[(row * A_inc1 + A_start1) * A_internal_size2 + col * A_inc2 + A_start2] += tmp * vec2[col * inc2 + start2]; \n");
// Update alternative 2: column index scaled by A_internal_size1.
706 source.append(
" A[(row * A_inc1 + A_start1) + (col * A_inc2 + A_start2) * A_internal_size1] += tmp * vec2[col * inc2 + start2]; \n");
707 source.append(
" } \n");
708 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel `trans_vec_mul`,
// which writes into `result` the dot products of `v` with the lines of A addressed through the
// column parameters (A_col_*), i.e. the loop named `row` runs up to A_col_size.
// NOTE(review): the function signature and the host-side condition selecting between the two
// variants below are not visible in this excerpt; the trailing " } " / "} " text closes the
// generated row loop and the kernel.
711 template <
typename StringType>
714 source.append(
"__kernel void trans_vec_mul( \n");
715 source.append(
" __global const "); source.append(numeric_string); source.append(
" * A, \n");
716 source.append(
" unsigned int A_row_start, unsigned int A_col_start, \n");
717 source.append(
" unsigned int A_row_inc, unsigned int A_col_inc, \n");
718 source.append(
" unsigned int A_row_size, unsigned int A_col_size, \n");
719 source.append(
" unsigned int A_internal_rows, unsigned int A_internal_cols, \n");
720 source.append(
" __global const "); source.append(numeric_string); source.append(
" * v, \n");
721 source.append(
" unsigned int v_start, unsigned int v_inc, unsigned int v_size, \n");
722 source.append(
" __global "); source.append(numeric_string); source.append(
" * result, \n");
723 source.append(
" unsigned int result_start, unsigned int result_inc, unsigned int result_size, \n");
724 source.append(
" __local "); source.append(numeric_string); source.append(
" * work) \n");
725 source.append(
"{ \n");
// Variant 1: one work-item computes a complete dot product; `work` is not used.
728 source.append(
" for (unsigned int row = get_global_id(0); row < A_col_size; row += get_global_size(0)) \n");
729 source.append(
" { \n");
730 source.append(
" "); source.append(numeric_string); source.append(
" dot_prod = 0; \n");
731 source.append(
" for (unsigned int col = 0; col < A_row_size; ++col) \n");
732 source.append(
" dot_prod += A[(row * A_col_inc + A_col_start) + (col * A_row_inc + A_row_start) * A_internal_cols] * v[v_start + v_inc * col]; \n");
733 source.append(
" result[row * result_inc + result_start] = dot_prod; \n");
// Variant 2: the work-items of a work-group share one dot product and reduce their partial
// sums through the __local buffer `work`.
737 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0); \n");
738 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0); \n");
739 source.append(
" unsigned int lid = get_local_id(0); \n");
741 source.append(
" for (unsigned int row = row_gid; row < A_col_size; row += get_num_groups(0)) \n");
742 source.append(
" { \n");
743 source.append(
" "); source.append(numeric_string); source.append(
" dot_prod = 0; \n");
744 source.append(
" for (unsigned int col = col_gid; col < A_row_size; col+=get_local_size(0)) \n");
745 source.append(
" dot_prod += A[(row * A_col_inc + A_col_start) * A_internal_rows + col * A_row_inc + A_row_start] * v[v_start + v_inc * col]; \n");
// Fixed: synchronise before overwriting work[]. Without this barrier a work-item could store
// its partial sum for the next row while another work-item is still reading work[] in the last
// reduction step of the previous row. All work-items of a group derive `row` from
// get_global_id(0) / get_local_size(0), so the barrier sits in the same loop as the reduction barriers.
source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
746 source.append(
" work[lid] = dot_prod; \n");
// Halving reduction over work[]; the barrier at the top of each step orders the previous writes.
748 source.append(
" for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1){ \n");
749 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
750 source.append(
" if(lid < stride) \n");
751 source.append(
" work[lid] += work[lid+stride]; \n");
752 source.append(
" } \n");
// The first work-item of the group publishes the reduced value.
754 source.append(
" if(lid == 0) \n");
755 source.append(
" result[row * result_inc + result_start] = work[0]; \n");
757 source.append(
" } \n");
758 source.append(
"} \n");
// Fragment of a generator that appends the OpenCL C source of kernel
// `triangular_substitute_inplace`, an in-place substitution on the vector `v` using matrix A.
// `options` bit 0 = unit diagonal (skip the division by the diagonal entry),
// bit 1 = transposed access to A, bit 2 = lower solve (rows processed top-down, otherwise bottom-up).
// NOTE(review): the function signature and the host-side conditions selecting between the two
// A index forms are not visible in this excerpt.
761 template <
typename StringType>
764 source.append(
"__kernel void triangular_substitute_inplace( \n");
765 source.append(
" __global "); source.append(numeric_string); source.append(
" * A, \n");
766 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
767 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
768 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
769 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
770 source.append(
" __global "); source.append(numeric_string); source.append(
" * v, \n");
771 source.append(
" unsigned int v_start, \n");
772 source.append(
" unsigned int v_inc, \n");
773 source.append(
" unsigned int v_size, \n");
774 source.append(
" unsigned int options) \n");
775 source.append(
"{ \n");
776 source.append(
" "); source.append(numeric_string); source.append(
" temp; \n");
777 source.append(
" unsigned int unit_diagonal_flag = (options & (1 << 0)); \n");
778 source.append(
" unsigned int transposed_access_A = (options & (1 << 1)); \n");
779 source.append(
" unsigned int is_lower_solve = (options & (1 << 2)); \n");
780 source.append(
" unsigned int row; \n");
781 source.append(
" for (unsigned int rows_processed = 0; rows_processed < A_size1; ++rows_processed) \n");
782 source.append(
" { \n");
// Lower solve walks rows 0 .. A_size1-1, otherwise A_size1-1 .. 0.
783 source.append(
" row = is_lower_solve ? rows_processed : ((A_size1 - rows_processed) - 1); \n");
784 source.append(
" if (!unit_diagonal_flag) \n");
785 source.append(
" { \n");
786 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
// Only work-item 0 divides v[row] by the diagonal entry of A (two index-form alternatives).
787 source.append(
" if (get_global_id(0) == 0) \n");
789 source.append(
" v[row * v_inc + v_start] /= A[(row * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
791 source.append(
" v[row * v_inc + v_start] /= A[(row * A_inc1 + A_start1) + (row * A_inc2 + A_start2) * A_internal_size1]; \n");
792 source.append(
" } \n");
794 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
796 source.append(
" temp = v[row * v_inc + v_start]; \n");
// Elimination: lower solve updates entries row+1 .. A_size1-1, otherwise entries 0 .. row-1,
// distributed over work-items in steps of get_global_size(0).
798 source.append(
" for (int elim = (is_lower_solve ? (row + get_global_id(0) + 1) : get_global_id(0)); \n");
799 source.append(
" elim < (is_lower_solve ? A_size1 : row); \n");
800 source.append(
" elim += get_global_size(0)) \n");
// Update alternative 1: first index scaled by A_internal_size2; transposed access swaps row and elim.
803 source.append(
" v[elim * v_inc + v_start] -= temp * A[transposed_access_A ? ((row * A_inc1 + A_start1) * A_internal_size2 + (elim * A_inc2 + A_start2)) \n");
804 source.append(
" : ((elim * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2))]; \n");
// Update alternative 2: second index scaled by A_internal_size1.
808 source.append(
" v[elim * v_inc + v_start] -= temp * A[transposed_access_A ? ((row * A_inc1 + A_start1) + (elim * A_inc2 + A_start2) * A_internal_size1) \n");
809 source.append(
" : ((elim * A_inc1 + A_start1) + (row * A_inc2 + A_start2) * A_internal_size1)]; \n");
811 source.append(
" } \n");
812 source.append(
"} \n");
/** @brief Appends the OpenCL source text of the kernel 'vec_mul' (result = A * v) to 'source'.
*
* 'numeric_string' is spliced in as the element type of A, v, result, the local scratch buffer 'work'
* and the accumulator 'dot_prod'. Two kernel bodies appear below:
*  - a work-group-per-row variant that reduces partial dot products through 'work', and
*  - a work-item-per-row variant that sums each row serially.
*
* NOTE(review): the embedded line numbers of this listing jump (816-817, 830-831, 851-854, 861), so the
* function signature and whatever selects between the two bodies are not visible here. Presumably a
* row-major/column-major flag emits exactly one of them, both sharing the closing braces at the end --
* confirm against the full file.
*/
815 template <
typename StringType>
// Kernel signature: matrix A (start/increment/size/internal size per dimension), input vector v, output vector 'result', local scratch 'work'.
818 source.append(
"__kernel void vec_mul( \n");
819 source.append(
" __global const "); source.append(numeric_string); source.append(
" * A, \n");
820 source.append(
" unsigned int A_row_start, unsigned int A_col_start, \n");
821 source.append(
" unsigned int A_row_inc, unsigned int A_col_inc, \n");
822 source.append(
" unsigned int A_row_size, unsigned int A_col_size, \n");
823 source.append(
" unsigned int A_internal_rows, unsigned int A_internal_cols, \n");
824 source.append(
" __global const "); source.append(numeric_string); source.append(
" * v, \n");
825 source.append(
" unsigned int v_start, unsigned int v_inc, unsigned int v_size, \n");
826 source.append(
" __global "); source.append(numeric_string); source.append(
" * result, \n");
827 source.append(
" unsigned int result_start, unsigned int result_inc, unsigned int result_size, \n");
828 source.append(
" __local "); source.append(numeric_string); source.append(
" * work) \n");
829 source.append(
"{ \n");
// First body: row_gid = global id / local size, col_gid = global id % local size, lid = index into 'work'.
832 source.append(
" unsigned int row_gid = get_global_id(0) / get_local_size(0); \n");
833 source.append(
" unsigned int col_gid = get_global_id(0) % get_local_size(0); \n");
834 source.append(
" unsigned int lid = get_local_id(0); \n");
// Rows are visited with a stride of get_num_groups(0), starting at row_gid.
836 source.append(
" for (unsigned int row = row_gid; row < A_row_size; row += get_num_groups(0)) \n");
837 source.append(
" { \n");
838 source.append(
" "); source.append(numeric_string); source.append(
" dot_prod = 0; \n");
// Each work-item accumulates the columns col_gid, col_gid + local size, ...; A is indexed with the row scaled by A_internal_cols.
839 source.append(
" for (unsigned int col = col_gid; col < A_col_size; col+=get_local_size(0)) \n");
840 source.append(
" dot_prod += A[(row * A_row_inc + A_row_start) * A_internal_cols + col * A_col_inc + A_col_start] * v[v_start + v_inc * col]; \n");
// NOTE(review): no barrier separates the reads of 'work' in the previous row iteration from this write;
// check whether work-items can race when one work-group handles more than one row.
841 source.append(
" work[lid] = dot_prod; \n");
// Halving reduction over 'work' with a barrier before each step.
// NOTE(review): every entry is folded into work[0] only if get_local_size(0) is a power of two.
843 source.append(
" for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1){ \n");
844 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
845 source.append(
" if(lid < stride) \n");
846 source.append(
" work[lid] += work[lid+stride]; \n");
847 source.append(
" } \n");
// The work-item with local id 0 stores the reduced sum for this row.
849 source.append(
" if(lid == 0) \n");
850 source.append(
" result[row * result_inc + result_start] = work[0]; \n");
// Second body: each work-item handles whole rows, strided by the global size, without using 'work'.
855 source.append(
" for (unsigned int row = get_global_id(0); row < A_row_size; row += get_global_size(0)) \n");
856 source.append(
" { \n");
857 source.append(
" "); source.append(numeric_string); source.append(
" dot_prod = 0; \n");
// Serial sum over all columns; A is indexed with the column scaled by A_internal_rows.
858 source.append(
" for (unsigned int col = 0; col < A_col_size; ++col) \n");
859 source.append(
" dot_prod += A[(row * A_row_inc + A_row_start) + (col * A_col_inc + A_col_start) * A_internal_rows] * v[v_start + v_inc * col]; \n");
860 source.append(
" result[row * result_inc + result_start] = dot_prod; \n");
// Close the loop over rows and the kernel body.
862 source.append(
" } \n");
863 source.append(
"} \n");
/** @brief Fragments of the kernel-program initialization, parametrized by the numeric type 'NumericT' and a second type 'F'.
*
* The visible statements build the program text in 'source', register it with the context 'ctx' under
* 'prog_name', and record the context handle in 'init_done'.
*
* NOTE(review): most original lines (struct/function headers, the generator calls, closing braces and the
* matching #endif) are missing from this listing; the comments below describe only what is visible.
*/
876 template <
typename NumericT,
typename F>
// Flag per OpenCL context handle; set to true below once the program has been added.
// NOTE(review): presumably consulted earlier to skip repeated initialization -- the check is not visible here.
890 static std::map<cl_context, bool> init_done;
// Pre-allocate 8192 characters for the program text.
894 source.reserve(8192);
// Helper defined elsewhere, called with NumericT, the context and the source buffer; presumably emits the double-precision pragma when needed -- confirm.
896 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
// Condition on floating-point element types; the guarded statements are not visible in this listing.
910 if (numeric_string ==
"float" || numeric_string ==
"double")
// Diagnostic output, compiled in only when VIENNACL_BUILD_INFO is defined.
918 #ifdef VIENNACL_BUILD_INFO
919 std::cout <<
"Creating program " << prog_name << std::endl;
// Register the assembled source under 'prog_name' and mark this context handle as initialized.
921 ctx.add_program(source, prog_name);
922 init_done[ctx.handle().get()] =
true;
bool is_row_major(viennacl::row_major_tag)
Definition: common.hpp:73
void generate_ambm_impl(StringType &source, std::string const &numeric_string, ambm_config const &cfg)
Definition: matrix.hpp:121
ambm_scalar_type a
Definition: matrix.hpp:38
Helper class for checking whether a matrix has a row-major layout.
Definition: forwards.h:399
Definition: matrix.hpp:26
bool with_stride_and_range
Definition: matrix.hpp:35
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
Definition: context.hpp:51
void generate_assign_cpu(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:257
Provides OpenCL-related utilities.
void generate_fft(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:383
static void init(viennacl::ocl::context &ctx)
Definition: matrix.hpp:884
void generate_element_op(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:306
ambm_scalar_type
Enumeration for the scalar type in ambm-like operations.
Definition: matrix.hpp:23
Definition: matrix.hpp:27
const OCL_TYPE & get() const
Definition: handle.hpp:189
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
Definition: context.hpp:476
Main namespace in ViennaCL. Holds all the basic types such as vector, matrix, etc. and defines operations upon them.
Definition: cpu_ram.hpp:29
void generate_ambm(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:230
void generate_scaled_rank1_update(StringType &source, std::string const &numeric_string, bool is_row_major, bool alpha_on_cpu)
Definition: matrix.hpp:660
std::string assign_op
Definition: matrix.hpp:37
bool is_row_major
Definition: matrix.hpp:36
void generate_trans_vec_mul(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:712
void generate_vec_mul(StringType &source, std::string const &numeric_string)
Definition: compressed_compressed_matrix.hpp:23
ambm_config()
Definition: matrix.hpp:33
static void apply(viennacl::ocl::context const &)
Definition: utils.hpp:40
Configuration struct for generating OpenCL kernels for linear combinations of matrices.
Definition: matrix.hpp:31
ambm_scalar_type b
Definition: matrix.hpp:39
void generate_ambm_impl2(StringType &source, ambm_config const &cfg, bool mult_alpha, bool mult_beta)
Definition: matrix.hpp:44
Main kernel class for generating OpenCL kernels for operations on/with dense matrix objects of type v...
Definition: matrix.hpp:877
Representation of an OpenCL kernel in ViennaCL.
std::string type_to_string(viennacl::row_major)
Definition: matrix.hpp:868
void generate_lu(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:605
A tag for column-major storage of a dense matrix.
Definition: forwards.h:263
static std::string program_name()
Definition: matrix.hpp:879
void generate_diagonal_assign_cpu(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:287
Definition: matrix.hpp:25
Helper class for converting a type to its string representation.
Definition: utils.hpp:57
A tag for row-major storage of a dense matrix.
Definition: forwards.h:246
void generate_triangular_substitute_inplace(StringType &source, std::string const &numeric_string, bool is_row_major)
Definition: matrix.hpp:762