1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
// Generator that appends the OpenCL source of one dense triangular-solve kernel
// (matrix right-hand side B, system matrix A) to `source`.
// NOTE(review): this listing is an extract with gaps. The function header line, all braces and
// most of the host-side `if`/`else` lines that choose between the alternatives below are not
// visible here; comments marked "presumably" describe the likely selector and must be confirmed
// against the complete file.
22 template <
typename StringType>
24 bool row_major_A,
bool row_major_B,
25 bool transpose_A,
bool transpose_B,
26 bool upper_solve,
bool unit_diagonal)
// Kernel name: "__kernel void " followed by optional name fragments and the fixed suffix "solve".
// The conditions guarding each fragment are not visible. Presumably the two "trans_" fragments
// follow transpose_A / transpose_B, "unit_" follows unit_diagonal, and "upper_" / "lower_" are the
// two branches on upper_solve -- TODO confirm.
29 source.append(
"__kernel void ");
31 source.append(
"trans_");
33 source.append(
"unit_");
35 source.append(
"upper_");
37 source.append(
"lower_");
39 source.append(
"trans_");
40 source.append(
"solve");
// Kernel argument list. A is passed as read-only (__global const), B as writable global memory.
// Each matrix is accompanied by start offsets, increments, sizes and internal sizes for both
// dimensions; numeric_string supplies the element type name.
42 source.append(
"( \n");
43 source.append(
" __global const "); source.append(numeric_string); source.append(
" * A, \n");
44 source.append(
" unsigned int A_start1, unsigned int A_start2, \n");
45 source.append(
" unsigned int A_inc1, unsigned int A_inc2, \n");
46 source.append(
" unsigned int A_size1, unsigned int A_size2, \n");
47 source.append(
" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
48 source.append(
" __global "); source.append(numeric_string); source.append(
" * B, \n");
49 source.append(
" unsigned int B_start1, unsigned int B_start2, \n");
50 source.append(
" unsigned int B_inc1, unsigned int B_inc2, \n");
51 source.append(
" unsigned int B_size1, unsigned int B_size2, \n");
52 source.append(
" unsigned int B_internal_size1, unsigned int B_internal_size2) { \n");
// Kernel-local scratch variable of the numeric type; it later holds one entry of B.
53 source.append(
" "); source.append(numeric_string); source.append(
" temp; \n");
// Row loop, variant 1: `row` runs from A_size1-1 down to 0 (counted via row_cnt).
// Presumably emitted for upper_solve (backward substitution) -- TODO confirm.
57 source.append(
" for (unsigned int row_cnt = 0; row_cnt < A_size1; ++row_cnt) \n");
58 source.append(
" { \n");
59 source.append(
" unsigned int row = A_size1 - 1 - row_cnt; \n");
// Row loop, variant 2: `row` runs from 0 up to A_size1-1.
// Presumably the alternative branch (lower solve, forward substitution) -- TODO confirm.
63 source.append(
" for (unsigned int row = 0; row < A_size1; ++row) \n");
64 source.append(
" { \n");
// Diagonal scaling: after a global-memory barrier, the work-item with local id 0 divides the
// current B entry by the diagonal entry A(row,row).
// NOTE(review): presumably this part is generated only when unit_diagonal is false -- confirm.
69 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
70 source.append(
" if (get_local_id(0) == 0) \n");
// Address of the B entry being scaled, chosen by storage order and transposition of B.
// row_major_B picks the "* B_internal_size2" (row stride) or "* B_internal_size1" (column stride)
// form; transpose_B swaps the roles of `row` and get_group_id(0). In every form get_group_id(0)
// indexes the dimension of B that is not walked by `row`.
72 if (row_major_B && transpose_B)
73 source.append(
" B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)] /= ");
74 else if (row_major_B && !transpose_B)
75 source.append(
" B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] /= ");
76 else if (!row_major_B && transpose_B)
77 source.append(
" B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1] /= ");
78 else if (!row_major_B && !transpose_B)
79 source.append(
" B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] /= ");
// Divisor A(row,row) in two addressing forms: the first multiplies the first index by
// A_internal_size2, the second multiplies the second index by A_internal_size1.
// The selecting condition is not visible; presumably it is row_major_A -- TODO confirm.
82 source.append(
"A[(row * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
84 source.append(
"A[(row * A_inc1 + A_start1) + (row * A_inc2 + A_start2)*A_internal_size1]; \n");
// Second barrier, placed before the B entry for `row` is read back into temp.
87 source.append(
" barrier(CLK_GLOBAL_MEM_FENCE); \n");
// Load the B entry for the current row into temp; same four addressing cases as above.
89 if (row_major_B && transpose_B)
90 source.append(
" temp = B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)]; \n");
91 else if (row_major_B && !transpose_B)
92 source.append(
" temp = B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)]; \n");
93 else if (!row_major_B && transpose_B)
94 source.append(
" temp = B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1]; \n");
95 else if (!row_major_B && !transpose_B)
96 source.append(
" temp = B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1]; \n");
// Elimination step. The loop index `elim` starts at the work-item's local id and advances by the
// work-group size, so the work-items of one group share the entries to update.
98 source.append(
" //eliminate column of op(A) with index 'row' in parallel: \n");
// Variant 1 covers the indices below `row` (elim < row); variant 2 covers the indices above it
// (row < elim < A_size1). Presumably selected by upper_solve, matching the two row-loop
// variants -- TODO confirm.
100 source.append(
" for (unsigned int elim = get_local_id(0); elim < row; elim += get_local_size(0)) \n");
102 source.append(
" for (unsigned int elim = row + get_local_id(0) + 1; elim < A_size1; elim += get_local_size(0)) \n");
// Left-hand side of the update: the B entry at index `elim`, addressed as in the cases above,
// followed by "-= temp * " which is completed by the A coefficient appended next.
104 if (row_major_B && transpose_B)
105 source.append(
" B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (elim * B_inc2 + B_start2)] -= temp * ");
106 else if (row_major_B && !transpose_B)
107 source.append(
" B[(elim * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] -= temp * ");
108 else if (!row_major_B && transpose_B)
109 source.append(
" B[(get_group_id(0) * B_inc1 + B_start1) + (elim * B_inc2 + B_start2) * B_internal_size1] -= temp * ");
110 else if (!row_major_B && !transpose_B)
111 source.append(
" B[(elim * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] -= temp * ");
// Coefficient taken from A: without transposition the entry at (elim, row) is read, with
// transpose_A the indices are swapped to (row, elim); row_major_A selects the stride form.
113 if (row_major_A && transpose_A)
114 source.append(
"A[(row * A_inc1 + A_start1) * A_internal_size2 + (elim * A_inc2 + A_start2)]; \n");
115 else if (row_major_A && !transpose_A)
116 source.append(
"A[(elim * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
117 else if (!row_major_A && transpose_A)
118 source.append(
"A[(row * A_inc1 + A_start1) + (elim * A_inc2 + A_start2) * A_internal_size1]; \n");
119 else if (!row_major_A && !transpose_A)
120 source.append(
"A[(elim * A_inc1 + A_start1) + (row * A_inc2 + A_start2) * A_internal_size1]; \n");
// Close the row loop and then the kernel body in the generated source.
122 source.append(
" } \n");
123 source.append(
"} \n");
// Template header of the kernel class for the matrix solve kernels, parameterised on the numeric
// type and two further type parameters F1 and F2.
// NOTE(review): this listing is an extract with gaps. The class header, the member function
// headers, the heads of the generator calls, the braces and the closing #endif are not visible
// here; comments marked "presumably" must be confirmed against the complete file.
133 template <
class NumericT,
typename F1,
typename F2>
// Per-context flag map keyed by the raw cl_context; the entry for the current context is set to
// true at the end of this block.
149 static std::map<cl_context, bool> init_done;
// Reserve room in the source string up front, before the kernel text is appended.
153 source.reserve(8192);
// Helper called with the context and the source string; presumably it prepends the pragma that
// enables double precision when NumericT requires it -- TODO confirm against the helper.
155 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
// Condition restricting what follows to the numeric types "float" and "double".
158 if (numeric_string ==
"float" || numeric_string ==
"double")
// Sixteen argument tails, one per kernel-generator call (the call heads are not visible).
// Together they enumerate every combination of four boolean flags, counting up from
// (false,false,false,false) to (true,true,true,true). Presumably these are the trailing
// (transpose_A, transpose_B, upper_solve, unit_diagonal) arguments -- TODO confirm.
161 false,
false,
false,
false);
163 false,
false,
false,
true);
165 false,
false,
true,
false);
167 false,
false,
true,
true);
170 false,
true,
false,
false);
172 false,
true,
false,
true);
174 false,
true,
true,
false);
176 false,
true,
true,
true);
179 true,
false,
false,
false);
181 true,
false,
false,
true);
183 true,
false,
true,
false);
185 true,
false,
true,
true);
188 true,
true,
false,
false);
190 true,
true,
false,
true);
192 true,
true,
true,
false);
194 true,
true,
true,
true);
// Diagnostic message printed only when VIENNACL_BUILD_INFO is defined.
198 #ifdef VIENNACL_BUILD_INFO
199 std::cout <<
"Creating program " << prog_name << std::endl;
// Hand the assembled source to the context under prog_name, then record in init_done that this
// context has been handled.
201 ctx.add_program(source, prog_name);
202 init_done[ctx.handle().get()] =
true;
Helper class for checking whether a matrix has a row-major layout.
Definition: forwards.h:399
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
Definition: context.hpp:51
Provides OpenCL-related utilities.
const OCL_TYPE & get() const
Definition: handle.hpp:189
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
Definition: context.hpp:476
Main namespace in ViennaCL. Holds all the basic types such as vector, matrix, etc. and defines operations upon them.
Definition: cpu_ram.hpp:29
void generate_matrix_solve_blas3(StringType &source, std::string const &numeric_string, bool row_major_A, bool row_major_B, bool transpose_A, bool transpose_B, bool upper_solve, bool unit_diagonal)
Definition: matrix_solve.hpp:23
Main kernel class for the generation of matrix solve kernels.
Definition: matrix_solve.hpp:134
static void apply(viennacl::ocl::context const &)
Definition: utils.hpp:40
Representation of an OpenCL kernel in ViennaCL.
std::string type_to_string(viennacl::row_major)
Definition: matrix.hpp:868
static void init(viennacl::ocl::context &ctx)
Definition: matrix_solve.hpp:141
static std::string program_name()
Definition: matrix_solve.hpp:136
Helper class for converting a type to its string representation.
Definition: utils.hpp:57
Runtime generation of OpenCL kernels for matrix operations.