//
//  CUDARuntime.hpp
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef CUDARuntime_hpp
#define CUDARuntime_hpp

#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>

#include "Type_generated.h"
#include "core/Macro.h"
typedef enum {
    CUDA_FLOAT32 = 0,
    CUDA_FLOAT16 = 1,
} MNNCUDADataType_t;

typedef enum {
    MNNMemcpyHostToDevice   = 1,
    MNNMemcpyDeviceToHost   = 2,
    MNNMemcpyDeviceToDevice = 3,
} MNNMemcpyKind_t;
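
// Editorial note (hedged): the enumerator values 1/2/3 line up with the CUDA
// runtime's cudaMemcpyHostToDevice, cudaMemcpyDeviceToHost and
// cudaMemcpyDeviceToDevice, which presumably lets CUDARuntime::memcpy() map the
// kind straight through; treat that mapping as an assumption about the .cu side.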
#define cuda_check(_x)                 \
    do {                               \
        cudaError_t _err = (_x);       \
        if (_err != cudaSuccess) {     \
            MNN_CHECK(_err, #_x);      \
        }                              \
    } while (0)

#define after_kernel_launch()               \
    do {                                    \
        cuda_check(cudaGetLastError());     \
    } while (0)
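
// Hedged usage sketch (editorial addition, not part of the original header):
// how the two macros above are typically paired around a runtime call and a
// kernel launch. `copyKernel`, `dst`, `src`, `bytes`, `grid` and `block` are
// hypothetical names used only for illustration.
//
//     cuda_check(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
//     copyKernel<<<grid, block>>>(dst, bytes);
//     after_kernel_launch(); // routes cudaGetLastError() through cuda_check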
#ifdef DEBUG
#define checkKernelErrors                                                   \
    do {                                                                    \
        cudaDeviceSynchronize();                                            \
        cudaError_t __err = cudaGetLastError();                             \
        if (__err != cudaSuccess) {                                         \
            printf("File:%s Line %d: failed: %s\n", __FILE__, __LINE__,     \
                   cudaGetErrorString(__err));                              \
            abort();                                                        \
        }                                                                   \
    } while (0)

#define cutlass_check(status)                                               \
    {                                                                       \
        cutlass::Status error = status;                                     \
        if (error != cutlass::Status::kSuccess) {                           \
            printf("File:%s Line %d: failed: %s\n", __FILE__, __LINE__,     \
                   cutlassGetStatusString(error));                          \
            abort();                                                        \
        }                                                                   \
    }
#else
#define checkKernelErrors
#define cutlass_check
#endif
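
// Hedged usage sketch (editorial addition): in DEBUG builds the macros above
// synchronize and validate after a kernel or CUTLASS call; in release builds
// checkKernelErrors expands to nothing and cutlass_check drops the status check
// (the wrapped call still runs). `myKernel` and `gemmOp` are hypothetical names.
//
//     myKernel<<<grid, block>>>(args);
//     checkKernelErrors;           // DEBUG: cudaDeviceSynchronize() + report first error
//     cutlass_check(gemmOp(args)); // DEBUG: abort unless cutlass::Status::kSuccess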
namespace MNN {

class CUDARuntime {
public:
    CUDARuntime(int device_id);
    ~CUDARuntime();
    CUDARuntime(const CUDARuntime &) = delete;
    CUDARuntime &operator=(const CUDARuntime &) = delete;

    // Device capability queries.
    bool isSupportedFP16() const;
    bool isSupportedDotInt8() const;
    bool isSupportedDotAccInt8() const;
    std::vector<size_t> getMaxImage2DSize();
    bool isCreateError() const;
    float flops() const {
        return mFlops;
    }
    int device_id() const;
    size_t mem_alignment_in_bytes() const;

    // Device selection and memory management.
    void activate();
    void *alloc(size_t size_in_bytes);
    void free(void *ptr);
    void memcpy(void *dst, const void *src, size_t size_in_bytes, MNNMemcpyKind_t kind, bool sync = false);
    void memset(void *dst, int value, size_t size_in_bytes);

    // Launch-geometry helpers derived from the cached cudaDeviceProp.
    size_t threads_num() {
        return mThreadPerBlock;
    }
    const cudaDeviceProp &prop() const {
        return mProp;
    }
    int major_sm() const {
        return mProp.major;
    }
    int compute_capability() {
        return mProp.major * 10 + mProp.minor;
    }
    size_t blocks_num(const size_t total_threads);
    const int smemPerBlock() {
        return mProp.sharedMemPerBlock;
    }
    int selectDeviceMaxFreeMemory();

private:
    cudaDeviceProp mProp;
    int mDeviceId;

    bool mIsSupportedFP16   = false;
    bool mSupportDotInt8    = false;
    bool mSupportDotAccInt8 = false;
    float mFlops            = 4.0f;
    bool mIsCreateError{false};
    size_t mThreadPerBlock  = 128;
};
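
// Hedged usage sketch (editorial addition): one plausible way a backend drives
// CUDARuntime, assuming a host buffer `hostData` of `bytes` bytes and an element
// count `count`; the flow below is illustrative, not an MNN API contract.
//
//     MNN::CUDARuntime runtime(0 /* device_id */);
//     if (!runtime.isCreateError()) {
//         runtime.activate();
//         void *devPtr = runtime.alloc(bytes);
//         runtime.memcpy(devPtr, hostData, bytes, MNNMemcpyHostToDevice, true /* sync */);
//         // ... launch kernels with <<<runtime.blocks_num(count), runtime.threads_num()>>> ...
//         after_kernel_launch();
//         runtime.memcpy(hostData, devPtr, bytes, MNNMemcpyDeviceToHost, true);
//         runtime.free(devPtr);
//     }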
} // namespace MNN

#endif /* CUDARuntime_hpp */