Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/api/c/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ af_err af_array_to_string(char **output, const char *exp, const af_array arr,
case u16:
print<ushort>(exp, arr, precision, ss, transpose);
break;
case f16:
print<half>(exp, arr, precision, ss, transpose);
break;
default: TYPE_ERROR(1, type);
}
}
Expand Down
17 changes: 17 additions & 0 deletions src/backend/cuda/Array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ using std::shared_ptr;
using std::vector;

namespace cuda {

template<typename T>
void verifyTypeSupport() {
if ((std::is_same<T, double>::value || std::is_same<T, cdouble>::value) &&
!isDoubleSupported(getActiveDeviceId())) {
AF_ERROR("Double precision not supported", AF_ERR_NO_DBL);
} else if (std::is_same<T, common::half>::value &&
!isHalfSupported(getActiveDeviceId())) {
AF_ERROR("Half precision not supported", AF_ERR_NO_HALF);
}
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move this to src/backend/common or src/api/c since this can be used from other locations too.

Also, Don't we need similar checks in create*Array helpers in others backends too ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already done in the OpenCL backend. It was not done in this backend because I intended to support half on older hardware and forgot to put it back in once I abandoned that idea.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, lets try to move this check into src/api/c/ a more common location is perhaps af_create_array ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lots of functions create arrays. for example randu. I think implementing it in the common namespace is sufficient.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whichever location makes sense most.


template<typename T>
Node_ptr bufferNodePtr() {
return Node_ptr(new BufferNode<T>(getFullName<T>(), shortname<T>(true)));
Expand Down Expand Up @@ -302,31 +314,36 @@ kJITHeuristics passesJitHeuristics(Node *root_node) {

template<typename T>
Array<T> createNodeArray(const dim4 &dims, Node_ptr node) {
verifyTypeSupport<T>();
Array<T> out = Array<T>(dims, node);
return out;
}

template<typename T>
Array<T> createHostDataArray(const dim4 &dims, const T *const data) {
verifyTypeSupport<T>();
bool is_device = false;
bool copy_device = false;
return Array<T>(dims, data, is_device, copy_device);
}

template<typename T>
Array<T> createDeviceDataArray(const dim4 &dims, void *data) {
verifyTypeSupport<T>();
bool is_device = true;
bool copy_device = false;
return Array<T>(dims, static_cast<T *>(data), is_device, copy_device);
}

template<typename T>
Array<T> createValueArray(const dim4 &dims, const T &value) {
verifyTypeSupport<T>();
return createScalarNode<T>(dims, value);
}

template<typename T>
Array<T> createEmptyArray(const dim4 &dims) {
verifyTypeSupport<T>();
return Array<T>(dims);
}

Expand Down
14 changes: 11 additions & 3 deletions src/backend/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,17 @@ bool isDoubleSupported(int device) {
}

bool isHalfSupported(int device) {
auto prop = getDeviceProp(device);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't getDeviceProp inexpensive ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but this will be slightly faster.

float compute = prop.major * 1000 + prop.minor * 10;
return compute >= 5030;
std::array<bool, DeviceManager::MAX_DEVICES> half_supported = []() {
std::array<bool, DeviceManager::MAX_DEVICES> out;
int count = getDeviceCount();
for (int i = 0; i < count; i++) {
auto prop = getDeviceProp(i);
float compute = prop.major * 1000 + prop.minor * 10;
out[i] = compute >= 5030;
}
return out;
}();
return half_supported[device];
}

void devprop(char *d_name, char *d_platform, char *d_toolkit, char *d_compute) {
Expand Down