Skip to content

Commit

Permalink
Fix SYCL accessor subrange calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Aug 18, 2021
1 parent 58c1114 commit 0e28244
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
7 changes: 6 additions & 1 deletion include/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ namespace detail {

// Forward declaration only; the definition of this proxy template is not part of this excerpt.
template <typename DataT, int Dims, cl::sycl::access::mode Mode, cl::sycl::target Target, int Index>
class accessor_subscript_proxy;

// Forward declaration of a test-only helper. It is declared here so that accessor classes can
// name it in a friend declaration; the struct itself is defined by the test code.
struct accessor_testspy;

} // namespace detail

/**
Expand Down Expand Up @@ -129,6 +132,8 @@ class host_memory_layout {
*/
template <typename DataT, int Dims, cl::sycl::access_mode Mode>
class accessor<DataT, Dims, Mode, cl::sycl::target::device> : public detail::accessor_base<DataT, Dims, Mode, cl::sycl::target::device> {
friend struct detail::accessor_testspy;

public:
accessor(const accessor& other) : sycl_accessor(other.sycl_accessor) { init_from(other); }

Expand All @@ -151,7 +156,7 @@ class accessor<DataT, Dims, Mode, cl::sycl::target::device> : public detail::acc
auto access_info = detail::runtime::get_instance().get_buffer_manager().get_device_buffer<DataT, Dims>(
detail::get_buffer_id(buff), Mode, detail::range_cast<3>(sr.range), detail::id_cast<3>(sr.offset));
eventual_sycl_cgh = live_cgh.get_eventual_sycl_cgh();
sycl_accessor = sycl_accessor_t(access_info.buffer, buff.get_range(), access_info.offset);
sycl_accessor = sycl_accessor_t(access_info.buffer, sr.range, sr.offset - access_info.offset);
backing_buffer_offset = access_info.offset;
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ namespace detail {
* FIXME: The current buffer locking mechanism limits task parallelism. Come up with a better solution.
*/
class buffer_manager {
friend struct buffer_manager_testspy;

public:
enum class buffer_lifecycle_event { REGISTERED, UNREGISTERED };

Expand Down
43 changes: 43 additions & 0 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ namespace detail {

/**
 * Builds a 3D grid region made up of exactly one box.
 *
 * @param range  Extent handed to make_grid_box.
 * @param offset Offset handed to make_grid_box; value-initialized when omitted.
 * @return A GridRegion<3> constructed from the box returned by make_grid_box(range, offset).
 */
GridRegion<3> make_grid_region(cl::sycl::range<3> range, cl::sycl::id<3> offset = {}) {
	return GridRegion<3>(
	    make_grid_box(range, offset));
}

/**
 * Test-only helper that hands out a reference to the `sycl_accessor` member of an accessor object.
 *
 * For classes where that member is not public, this only compiles if the class declares
 * `accessor_testspy` as a friend.
 */
struct accessor_testspy {
	/**
	 * @param celerity_acc Object owning a `sycl_accessor` member.
	 * @return Reference to that member (const-qualified if `celerity_acc` is const).
	 */
	template <typename CelerityAccessor>
	static auto& get_sycl_accessor(CelerityAccessor& celerity_acc) {
		auto& wrapped_accessor = celerity_acc.sycl_accessor;
		return wrapped_accessor;
	}
};

/**
 * Test-only helper that reads the device backing buffer of a virtual buffer straight out of the
 * buffer_manager's internal bookkeeping (requires buffer_manager to befriend this struct).
 */
struct buffer_manager_testspy {
	/**
	 * Looks up the device-side storage currently recorded for `bid`.
	 *
	 * @tparam DataT Element type the storage is expected to hold.
	 * @tparam Dims  Dimensionality the storage is expected to have.
	 * @param bm  Buffer manager to inspect; its mutex is held for the duration of the call.
	 * @param bid Id of the buffer to look up. `buffers.at(bid)` is used for the lookup, so an
	 *            unknown id is reported by that call rather than silently inserted.
	 * @return The device buffer together with its offset, cast to `Dims` dimensions.
	 *
	 * Fails the running test (via REQUIRE) if there is no device storage for the buffer or if it
	 * is not a `device_buffer_storage<DataT, Dims>`, instead of dereferencing a null pointer.
	 */
	template <typename DataT, int Dims>
	static buffer_manager::access_info<DataT, Dims, device_buffer> get_device_buffer(buffer_manager& bm, buffer_id bid) {
		std::unique_lock lock(bm.mutex);
		auto& buf = bm.buffers.at(bid).device_buf;
		// dynamic_cast yields nullptr both for an empty storage pointer and for a type mismatch.
		auto* const typed_storage = dynamic_cast<device_buffer_storage<DataT, Dims>*>(buf.storage.get());
		REQUIRE(typed_storage != nullptr);
		return {typed_storage->get_device_buffer(), id_cast<Dims>(buf.offset)};
	}
};

TEST_CASE("only a single distr_queue can be created", "[distr_queue][lifetime][dx]") {
distr_queue q1;
auto q2{q1}; // Copying is allowed
Expand Down Expand Up @@ -2167,5 +2183,32 @@ namespace detail {
}));
}

// Verifies that the SYCL accessor wrapped by a Celerity device accessor covers exactly the requested
// subrange, with its offset expressed relative to the start of the device backing buffer.
TEST_CASE("SYCL accessors receive correct backing-buffer relative ranges and offsets", "[accessor]") {
distr_queue q;
buffer<int, 3> virtual_buf{cl::sycl::range<3>{1000, 1000, 1000}};
// Subranges are given as {offset, range}; the small one lies entirely inside the large one.
subrange<3> large_accessor_sr{{117, 118, 119}, {301, 302, 303}};
subrange<3> small_accessor_sr{{207, 206, 205}, {101, 102, 103}};

q.submit([=](handler& cgh) {
// Both accessors refer to the same virtual buffer, each through a fixed range mapper.
accessor large_celerity_acc{virtual_buf, cgh, celerity::access::fixed{large_accessor_sr}, cl::sycl::read_write};
accessor small_celerity_acc{virtual_buf, cgh, celerity::access::fixed{small_accessor_sr}, cl::sycl::read_write};
// NOTE(review): presumably the device backing buffer only exists outside the prepass - confirm.
if(!is_prepass_handler(cgh)) {
auto& bm = runtime::get_instance().get_buffer_manager();
// Read the backing buffer and its offset directly from the buffer manager's internals.
auto info = buffer_manager_testspy::get_device_buffer<int, 3>(bm, get_buffer_id(virtual_buf));
subrange<3> backing_buffer_sr{info.offset, info.buffer.get_range()};

auto& large_sycl_acc = accessor_testspy::get_sycl_accessor(large_celerity_acc);
auto& small_sycl_acc = accessor_testspy::get_sycl_accessor(small_celerity_acc);

// The SYCL accessor range must equal the requested subrange's range ...
CHECK(large_sycl_acc.get_range() == large_accessor_sr.range);
CHECK(small_sycl_acc.get_range() == small_accessor_sr.range);
// ... and its offset must be the requested offset minus the backing buffer's offset.
CHECK(large_sycl_acc.get_offset() == large_accessor_sr.offset - backing_buffer_sr.offset);
CHECK(small_sycl_acc.get_offset() == small_accessor_sr.offset - backing_buffer_sr.offset);
// Cross-check: the two SYCL offsets differ by the same amount as the two requested offsets.
CHECK(small_sycl_acc.get_offset() == large_sycl_acc.get_offset() + (small_accessor_sr.offset - large_accessor_sr.offset));
}
// Single-item kernel with an empty body; the test only inspects the accessors above.
cgh.parallel_for<class UKN(dummy)>(cl::sycl::range<3>{1, 1, 1}, [](cl::sycl::item<3>) {});
});
}

} // namespace detail
} // namespace celerity

0 comments on commit 0e28244

Please sign in to comment.