00001 #ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_VECTOR_AXPY_HPP
00002 #define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_VECTOR_AXPY_HPP
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00027 #include <vector>
00028 #include <cmath>
00029
00030 #include "viennacl/scheduler/forwards.h"
00031
00032 #include "viennacl/device_specific/mapped_objects.hpp"
00033 #include "viennacl/device_specific/tree_parsing.hpp"
00034 #include "viennacl/device_specific/forwards.h"
00035 #include "viennacl/device_specific/utils.hpp"
00036
00037 #include "viennacl/device_specific/templates/template_base.hpp"
00038 #include "viennacl/device_specific/templates/utils.hpp"
00039
00040 #include "viennacl/tools/tools.hpp"
00041
00042 namespace viennacl
00043 {
00044 namespace device_specific
00045 {
00046
00047 class vector_axpy_parameters : public template_base::parameters_type
00048 {
00049 public:
00050 vector_axpy_parameters(unsigned int _simd_width,
00051 unsigned int _group_size, unsigned int _num_groups,
00052 fetching_policy_type _fetching_policy) : template_base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy){ }
00053
00054
00055
00056 unsigned int num_groups;
00057 fetching_policy_type fetching_policy;
00058 };
00059
00060 class vector_axpy_template : public template_base_impl<vector_axpy_template, vector_axpy_parameters>
00061 {
00062 private:
00063 virtual int check_invalid_impl(viennacl::ocl::device const & ) const
00064 {
00065 if (p_.fetching_policy==FETCH_FROM_LOCAL)
00066 return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
00067 return TEMPLATE_VALID;
00068 }
00069
00070 std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings) const
00071 {
00072 std::vector<std::string> result;
00073 for (unsigned int i = 0; i < 2; ++i)
00074 {
00075 utils::kernel_generation_stream stream;
00076 unsigned int simd_width = (i==0)?1:p_.simd_width;
00077 std::string suffix = (i==0)?"_strided":"";
00078 stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << ",1,1)))" << std::endl;
00079 generate_prototype(stream, kernel_prefix + suffix, "unsigned int N,", mappings, statements);
00080 stream << "{" << std::endl;
00081 stream.inc_tab();
00082
00083 tree_parsing::process(stream, PARENT_NODE_TYPE, "scalar", "#scalartype #namereg = *#pointer;", statements, mappings);
00084 tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#pointer += $OFFSET{#start1, #start2};", statements, mappings);
00085 tree_parsing::process(stream, PARENT_NODE_TYPE, "vector", "#pointer += #start;", statements, mappings);
00086
00087 struct loop_body : public loop_body_base
00088 {
00089 loop_body(statements_container const & statements_, std::vector<mapping_type> const & mappings_) : statements(statements_), mappings(mappings_) { }
00090
00091 void operator()(utils::kernel_generation_stream & kernel_stream, unsigned int kernel_simd_width) const
00092 {
00093 std::string process_str;
00094 std::string i_str = (kernel_simd_width==1)?"i*#stride":"i";
00095
00096 process_str = utils::append_width("#scalartype",kernel_simd_width) + " #namereg = " + vload(kernel_simd_width, i_str, "#pointer") + ";";
00097 tree_parsing::process(kernel_stream, PARENT_NODE_TYPE, "vector", process_str, statements, mappings);
00098 tree_parsing::process(kernel_stream, PARENT_NODE_TYPE, "matrix_row", "#scalartype #namereg = #pointer[$OFFSET{#row*#stride1, i*#stride2}];", statements, mappings);
00099 tree_parsing::process(kernel_stream, PARENT_NODE_TYPE, "matrix_column", "#scalartype #namereg = #pointer[$OFFSET{i*#stride1,#column*#stride2}];", statements, mappings);
00100 tree_parsing::process(kernel_stream, PARENT_NODE_TYPE, "matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride1, i*#stride2}:$OFFSET{i*#stride1, (i + #diag_offset)*#stride2}];", statements, mappings);
00101
00102 std::map<std::string, std::string> accessors;
00103 accessors["vector"] = "#namereg";
00104 accessors["matrix_row"] = "#namereg";
00105 accessors["matrix_column"] = "#namereg";
00106 accessors["matrix_diag"] = "#namereg";
00107 accessors["scalar"] = "#namereg";
00108 tree_parsing::evaluate(kernel_stream, PARENT_NODE_TYPE, accessors, statements, mappings);
00109
00110 process_str = vstore(kernel_simd_width, "#namereg",i_str,"#pointer")+";";
00111 tree_parsing::process(kernel_stream, LHS_NODE_TYPE, "vector", process_str, statements, mappings);
00112 tree_parsing::process(kernel_stream, LHS_NODE_TYPE, "matrix_row", "#pointer[$OFFSET{#row, i}] = #namereg;", statements, mappings);
00113 tree_parsing::process(kernel_stream, LHS_NODE_TYPE, "matrix_column", "#pointer[$OFFSET{i, #column}] = #namereg;", statements, mappings);
00114 tree_parsing::process(kernel_stream, LHS_NODE_TYPE, "matrix_diag", "#pointer[#diag_offset<0?$OFFSET{i - #diag_offset, i}:$OFFSET{i, i + #diag_offset}] = #namereg;", statements, mappings);
00115
00116 }
00117
00118 private:
00119 statements_container const & statements;
00120 std::vector<mapping_type> const & mappings;
00121 };
00122
00123 element_wise_loop_1D(stream, loop_body(statements, mappings), p_.fetching_policy, simd_width, "i", "N", "get_global_id(0)", "get_global_size(0)");
00124
00125 stream.dec_tab();
00126 stream << "}" << std::endl;
00127 result.push_back(stream.str());
00128 }
00129
00130 return result;
00131 }
00132
00133 public:
00134 vector_axpy_template(vector_axpy_template::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE) : template_base_impl<vector_axpy_template, vector_axpy_parameters>(parameters, binding_policy), up_to_internal_size_(false){ }
00135
00136 void up_to_internal_size(bool v) { up_to_internal_size_ = v; }
00137
00138 void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements)
00139 {
00140 viennacl::ocl::kernel * kernel;
00141 if (has_strided_access(statements) && p_.simd_width > 1)
00142 kernel = &programs[0].program().get_kernel(kernel_prefix+"_strided");
00143 else
00144 kernel = &programs[1].program().get_kernel(kernel_prefix);
00145
00146 kernel->local_work_size(0, p_.local_size_0);
00147 kernel->global_work_size(0, p_.local_size_0*p_.num_groups);
00148 unsigned int current_arg = 0;
00149 scheduler::statement const & statement = statements.data().front();
00150 cl_uint size = static_cast<cl_uint>(vector_size(lhs_most(statement.array(), statement.root()), up_to_internal_size_));
00151 kernel->arg(current_arg++, size);
00152 set_arguments(statements, *kernel, current_arg);
00153 viennacl::ocl::enqueue(*kernel);
00154 }
00155
00156 private:
00157 bool up_to_internal_size_;
00158 };
00159
00160 }
00161 }
00162
00163 #endif