SoftmaxExecution.hpp
//
//  SoftmaxExecution.hpp
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef SoftmaxExecution_hpp
#define SoftmaxExecution_hpp

#include <vector>
#include "core/Execution.hpp"
#include "backend/opencl/core/OpenCLBackend.hpp"
#include "backend/opencl/core/OpenCLRunningUtils.hpp"
#include "backend/opencl/execution/image/CommonExtension.hpp"

namespace MNN {
namespace OpenCL {

class SoftmaxExecution : public Execution, public CommonExtension {
public:
    SoftmaxExecution(const std::vector<Tensor *> &inputs, int axis, Backend *backend);
    virtual ~SoftmaxExecution() = default;

    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

    // Builds the OpenCL softmax kernel for the given local (work-group) size.
    bool buildSoftmaxKernel(int localSize);

private:
    // Chooses a work-group (local) size bounded by size and maxGroupSize.
    int getLocalSize(int size, int maxGroupSize);

    cl::Kernel mKernel;
    uint32_t mMaxWorkGroupSize;                       // maximum work-group size for mKernel
    OpenCLBackend *mOpenCLBackend;
    std::vector<uint32_t> mGlobalWorkSize{1, 1, 1};   // global NDRange, configured in onResize
    std::vector<uint32_t> mLocalWorkSize{1, 1, 1, 1}; // local NDRange, configured in onResize
    int mAxis;                                        // axis along which softmax is computed
};

} // namespace OpenCL
} // namespace MNN

#endif /* SoftmaxExecution_hpp */
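// ---------------------------------------------------------------------------
// Usage sketch (not part of SoftmaxExecution.hpp): in MNN, Execution subclasses
// such as SoftmaxExecution are normally instantiated per-op by a backend
// creator registered with the OpenCL backend. The creator below is a minimal,
// hypothetical sketch; the OpenCLBackend::Creator interface and the way the
// axis is read from the op follow the usual MNN pattern but are assumptions,
// not the library's actual softmax creator.
// ---------------------------------------------------------------------------
#include "backend/opencl/execution/image/SoftmaxExecution.hpp"

namespace MNN {
namespace OpenCL {

class SoftmaxCreatorSketch : public OpenCLBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs,
                                const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        // Assumption: the softmax axis is carried in the op's Axis parameter.
        int axis = op->main_as_Axis()->axis();
        return new SoftmaxExecution(inputs, axis, backend);
    }
};

} // namespace OpenCL
} // namespace MNN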