
Merge pull request #3723 from jules-ai/feature/specify-cpu

New feature: allow the user to specify which CPU cores to use
jxt1234 2 weeks ago
parent
commit
80a917a1b0

+ 4 - 1
docs/inference/session.md

@@ -181,7 +181,10 @@ struct BackendConfig {
     PrecisionMode precision = Precision_Normal;
     
     /** user defined context */
-    void* sharedContext = nullptr;
+    union {
+        void* sharedContext = nullptr;
+        size_t flags; // Valid for CPU Backend
+    };
 };
 ```
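
For reference, a minimal sketch of how the new union member might be used for the CPU backend (the header path and flag value are assumptions, not part of this change):

```cpp
#include <MNN/MNNForwardType.h>  // assumed location of MNN::BackendConfig

// Sketch: when the CPU backend is selected, the slot that used to hold
// sharedContext can now carry backend flags instead.
MNN::BackendConfig makeCpuConfig() {
    MNN::BackendConfig config;
    config.precision = MNN::BackendConfig::Precision_Normal;
    config.flags = 0;  // e.g. MNN_CPU_USE_DEFAULT_BACKEND for the plain CPU path
    return config;
}
```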
 

+ 8 - 0
express/Executor.cpp

@@ -231,6 +231,14 @@ void Executor::RuntimeManager::setHint(Interpreter::HintMode mode, int value) {
         iter.second->setRuntimeHint(mInside->mContent->modes.runtimeHint);
     }
 }
+void Executor::RuntimeManager::setHint(Interpreter::HintMode mode, int* value, size_t size) {
+    mInside->mContent->modes.setHint(mode, value, size);
+    auto current = ExecutorScope::Current();
+    auto rt = current->getRuntime();
+    for (auto& iter : rt.first) {
+        iter.second->setRuntimeHint(mInside->mContent->modes.runtimeHint);
+    }
+}
 void Executor::RuntimeManager::setExternalPath(std::string path, int type) {
     mInside->mContent->modes.setExternalPath(path, type);
 }
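
A minimal usage sketch for the new RuntimeManager overload (core IDs and thread count are illustrative; only the setHint(mode, int*, size_t) signature comes from this change):

```cpp
#include <memory>
#include <vector>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

// Build a CPU runtime manager whose worker threads are bound to the given
// core IDs. Invalid or cross-cluster IDs are rejected later by CPURuntime,
// in which case the hint is ignored.
std::shared_ptr<MNN::Express::Executor::RuntimeManager>
makePinnedRuntime(std::vector<int> cpuIds, int numThread) {
    MNN::ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;
    config.numThread = numThread;
    std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr(
        MNN::Express::Executor::RuntimeManager::createRuntimeManager(config));
    rtmgr->setHint(MNN::Interpreter::HintMode::CPU_CORE_IDS,
                   cpuIds.data(), cpuIds.size());
    return rtmgr;
}
```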

+ 8 - 3
include/MNN/Interpreter.hpp

@@ -245,7 +245,10 @@ public:
         USE_CACHED_MMAP = 12,
         
         // Multi-Thread Load module, default is 0 (don't use other Thread)
-        INIT_THREAD_NUMBER = 13
+        INIT_THREAD_NUMBER = 13,
+
+        // CPU core ids
+        CPU_CORE_IDS = 14,
     };
 
     enum ExternalPathType {
@@ -280,10 +283,12 @@ public:
 
     /**
      * @brief The API shoud be called before create session.
-     * @param mode      Hint type
+     * @param hint      Hint type
      * @param value     Hint value
+     * @param size      Hint value size (used when value is a pointer)
      */
-    void setSessionHint(HintMode mode, int value);
+    void setSessionHint(HintMode hint, int value);
+    void setSessionHint(HintMode hint, int* value, size_t size);
 public:
     /**
      * @brief create runtimeInfo separately with schedule config.
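
Correspondingly, a hedged sketch of the Interpreter-level call (the core IDs are illustrative):

```cpp
#include <vector>
#include <MNN/Interpreter.hpp>

// Pin the session's CPU threads to four (assumed) big cores. Call this before
// createSession; if any ID is invalid or the IDs span clusters, the runtime
// falls back to its default core selection.
void pinCores(MNN::Interpreter* net) {
    std::vector<int> cpuIds = {4, 5, 6, 7};
    net->setSessionHint(MNN::Interpreter::HintMode::CPU_CORE_IDS,
                        cpuIds.data(), cpuIds.size());
}
```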

+ 1 - 0
include/MNN/expr/Executor.hpp

@@ -126,6 +126,7 @@ public:
         friend class Executor;
         void setMode(Interpreter::SessionMode mode);
         void setHint(Interpreter::HintMode mode, int value);
+        void setHint(Interpreter::HintMode mode, int* value, size_t size);
         void setHintPtr(Interpreter::HintMode mode, void* value);
         bool getInfo(Interpreter::SessionInfoCode code, void* ptr);
         BackendConfig* getBnConfig();

+ 1 - 1
source/backend/arm82/Arm82Backend.cpp

@@ -24,7 +24,7 @@
 
 namespace MNN {
 
-Arm82Backend::Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory, int initThreadNumber) : CPUBackend(runtime, BackendConfig::Precision_Low, memory, MNN_FORWARD_CPU_EXTENSION, 0, initThreadNumber) {
+Arm82Backend::Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory) : CPUBackend(runtime, BackendConfig::Precision_Low, memory, MNN_FORWARD_CPU_EXTENSION, 0) {
     mCoreFunctions = Arm82Functions::get();
     mInt8CoreFunctions = Arm82Functions::getInt8();
 }

+ 1 - 1
source/backend/arm82/Arm82Backend.hpp

@@ -28,7 +28,7 @@ namespace MNN {
 class Arm82Backend : public CPUBackend {
 public:
     virtual ~Arm82Backend();
-    Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory, int initThreadNumber);
+    Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory);
     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                 const MNN::Op* op) override;
     virtual Backend::MemObj* onAcquire(const Tensor* nativeTensor, StorageType storageType) override;

+ 125 - 64
source/backend/cpu/CPUBackend.cpp

@@ -9,6 +9,7 @@
 #include "backend/cpu/CPUBackend.hpp"
 #include <cmath>
 #include <mutex>
+#include <unordered_map>
 #include "CPUResizeCache.hpp"
 #include "core/BufferAllocator.hpp"
 #include "CPUTensorConvert.hpp"
@@ -78,7 +79,7 @@ void CPUBackend::computeDivideSizes(int size, int* dst, float avgDiv) const {
 }
 
 void CPURuntime::_bindCPUCore() const {
-    if (mPower == BackendConfig::Power_Normal) {
+    if (mCpuIds.empty()) {
         return;
     }
     auto tid = MNNGetCurrentPid();
@@ -87,36 +88,11 @@ void CPURuntime::_bindCPUCore() const {
     }
     mCurrentTID = tid;
     // Bind CPU Core
-    auto cpuInfo = MNNGetCPUInfo();
-    if (cpuInfo->groups.size() == 0) {
-        return;
-    }
     std::vector<std::pair<const int*, int>> lockCPUIndexes(mThreadNumber);
-    switch (mPower) {
-        case BackendConfig::Power_Low:
-            for (int v=0; v<mThreadNumber; ++v) {
-                lockCPUIndexes[v] = std::make_pair(cpuInfo->groups[0].ids.data(), cpuInfo->groups[0].ids.size());
-            }
-            break;
-        case BackendConfig::Power_High:
-        {
-            int selectCPUSize = 0;
-            int groupIndex = cpuInfo->groups.size() - 1;
-            while (selectCPUSize < mThreadNumber && groupIndex >= 0) {
-                auto& group = cpuInfo->groups[groupIndex];
-                int size = ALIMIN(group.ids.size(), mThreadNumber - selectCPUSize);
-                for (int v=0; v<size; ++v) {
-                    lockCPUIndexes[v + selectCPUSize] = std::make_pair(group.ids.data(), group.ids.size());
-                }
-                groupIndex--;
-                selectCPUSize += group.ids.size();
-            }
-        }
-            break;
-        default:
-            break;
+    for (int v=0; v<mThreadNumber; ++v) {
+        lockCPUIndexes[v] = std::make_pair(mCpuIds.data(), mCpuIds.size());
     }
-        // Set CPU Affinity
+    // Set CPU Affinity
 #ifdef _OPENMP
     auto threadsNumber = mThreadNumber;
     std::vector<int> result(threadsNumber, 0);
@@ -126,30 +102,29 @@ void CPURuntime::_bindCPUCore() const {
     }
 #endif
 #ifdef MNN_USE_THREAD_POOL
-    ThreadPool::active(mThreadNumber);
-    ThreadPool::enqueue(std::make_pair([&](int i) {
-        MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second);
-        return 0;
-    }, mThreadNumber), mTaskIndex, mThreadNumber);
-    ThreadPool::deactive(mThreadNumber);
+    if(mThreadPool) {
+        mThreadPool->active();
+        mThreadPool->enqueue(std::make_pair([&](int i) {
+            MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second);
+            return 0;
+        }, mThreadNumber), mTaskIndex);
+        mThreadPool->deactive();
+    }
 #endif
 }
 
-void CPURuntime::_resetThreadPool() {
+void CPURuntime::_resetThreadPool() const{
     mThreadNumber = std::max(1, mThreadNumber);
     mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER);
 #ifdef MNN_USE_THREAD_POOL
-    ThreadPool::releaseWorkIndex(mTaskIndex);
-    auto cpuInfo = MNNGetCPUInfo();
-    if (mThreadNumber > 1) {
-        int systemThreadNumber = (int)cpuInfo->cpuNumber;
-        if (systemThreadNumber == 0) {
-            systemThreadNumber = mThreadNumber;
-        }
-        mThreadNumber = ALIMIN(ThreadPool::init(systemThreadNumber), mThreadNumber);
+    if (mThreadPool) {
+        mThreadPool->releaseWorkIndex(mTaskIndex);
     }
     if (mThreadNumber > 1) {
-        mTaskIndex = ThreadPool::acquireWorkIndex();
+        mThreadNumber = ALIMIN(ThreadPool::init(mThreadNumber, mCpuMask, mThreadPool), mThreadNumber);
+        if (mThreadPool) {
+            mTaskIndex = mThreadPool->acquireWorkIndex();
+        }
         if (-1 == mTaskIndex) {
             MNN_ERROR("The ThreadPool has been used to MNN_THREAD_POOL_MAX_TASKS, can't use thread pool\n");
             mThreadNumber = 1;
@@ -161,6 +136,86 @@ void CPURuntime::_resetThreadPool() {
     // Reset tid to rebind cpu if necessary
     mCurrentTID = 0;
 }
+void CPURuntime::_validateCpuIds() const{
+    bool valid = true;
+
+    do {
+        if (mCpuIds.empty()) {
+            valid = false;
+            break;
+        }
+
+        auto cpuInfo = MNNGetCPUInfo();
+        if (cpuInfo->groups.empty()) {
+            valid = false;
+            break;
+        }
+
+        std::unordered_map<int, bool> cpuLittleMap;
+        for (auto id : cpuInfo->groups[0].ids) {
+            cpuLittleMap[id] = true;
+        }
+        for (size_t i = 1; i < cpuInfo->groups.size(); i++) {
+            for (auto id : cpuInfo->groups[i].ids) {
+                cpuLittleMap[id] = false;
+            }
+        }
+
+        if (cpuLittleMap.find(mCpuIds[0]) == cpuLittleMap.end()) {
+            MNN_ERROR("CPU ID %d is not valid. CpuIds will not be used.\n", mCpuIds[0]);
+            valid = false;
+            break;
+        }
+
+        auto cpuLittle = cpuLittleMap[mCpuIds[0]];
+        for (size_t i = 1; i < mCpuIds.size(); i++) {
+            if (cpuLittleMap.find(mCpuIds[i]) == cpuLittleMap.end()) {
+                MNN_ERROR("CPU ID %d is not valid. CpuIds will not be used.\n", mCpuIds[i]);
+                valid = false;
+                break;
+            }
+            // Using the same group of CPU cores helps maximize multi thread performance.
+            // Mixing little cores with others can lead to significant performance degradation, so it is strictly prohibited.
+            // Even on architectures with more than two clusters, when little cores are not involved,
+            // it's still strongly recommended to avoid cross-cluster usage between different big core groups.
+            if (cpuLittleMap[mCpuIds[i]] != cpuLittle) {
+                MNN_ERROR("CPU ID %d and %d are not from the same group. CpuIds will not be used.\n", mCpuIds[0], mCpuIds[i]);
+                valid = false;
+                break;
+            }
+        }
+
+    } while (false);
+
+    if(!valid) {
+        mCpuIds.clear();
+    }
+
+    if(mCpuIds.empty()) {
+        auto cpuInfo = MNNGetCPUInfo();
+        if (cpuInfo->groups.size() == 0) {
+            return;
+        }
+        switch (mPower) {
+            case BackendConfig::Power_Low:
+                    mCpuIds = cpuInfo->groups[0].ids;
+                break;
+            case BackendConfig::Power_High: {
+                int selectCPUSize = 0;
+                int groupIndex = cpuInfo->groups.size() - 1;
+                while (selectCPUSize < mThreadNumber && groupIndex >= 0) {
+                    auto& group = cpuInfo->groups[groupIndex];
+                    mCpuIds.insert(mCpuIds.end(), group.ids.begin(), group.ids.end());
+                    groupIndex--;
+                    selectCPUSize += group.ids.size();
+                }
+            }
+                break;
+            default:
+                break;
+        }
+    }
+}
 void CPURuntime::onReset(int numberThread, const BackendConfig* config, bool full) {
     if (config != nullptr) {
         mPower = config->power;
@@ -171,6 +226,9 @@ void CPURuntime::onReset(int numberThread, const BackendConfig* config, bool ful
         }
     }
     mThreadNumber = numberThread;
+    mCpuIds = hint().cpuIds;
+    _validateCpuIds();
+    mCpuMask = MNNGetCPUMask(mCpuIds);
     _resetThreadPool();
 }
 
@@ -185,13 +243,13 @@ CPURuntime::CPURuntime(const Backend::Info& info) {
     mPower   = BackendConfig::Power_Normal;
     mMemory  = BackendConfig::Memory_Normal;
     mPrecision = BackendConfig::Precision_Normal;
+    mCpuIds.clear();
     if (info.user != nullptr) {
         mPrecision = info.user->precision;
         mPower = info.user->power;
         mMemory = info.user->memory;
         mFlags = info.user->flags;
     }
-    _resetThreadPool();
 #ifdef LOG_VERBOSE
     MNN_PRINT("create CPURuntime:%p\n", this);
 #endif
@@ -199,7 +257,9 @@ CPURuntime::CPURuntime(const Backend::Info& info) {
 
 CPURuntime:: ~ CPURuntime() {
 #ifdef MNN_USE_THREAD_POOL
-    ThreadPool::releaseWorkIndex(mTaskIndex);
+    if(mThreadPool) {
+        mThreadPool->releaseWorkIndex(mTaskIndex);
+    }
 #endif
 }
 float CPURuntime::onGetMemoryInMB() {
@@ -222,6 +282,12 @@ SingleBufferWithAllocator* CPURuntime::buffer(int index) const {
 }
 
 Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) const {
+    {
+        mCpuIds = hint().cpuIds;
+        _validateCpuIds();
+        mCpuMask = MNNGetCPUMask(mCpuIds);
+        _resetThreadPool();
+    }
     if (hint().midMemoryPath.size() > 0) {
         if (mDynamicMmap.empty()) {
             // Only support set featuremap dir once
@@ -270,25 +336,24 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
     MNN_PRINT("cpu backend was created by runtime:%p\n", this);
 #endif
     CPUBackend* res = nullptr;
-    auto initThreadNumber = hint().initThreadNumber;
     do {
 #ifdef MNN_USE_ARMV82
         auto core = MNNGetCoreFunctions();
         if (core->supportFp16arith && precision == BackendConfig::Precision_Low) {
-            res = new Arm82Backend(this, memory, initThreadNumber);
+            res = new Arm82Backend(this, memory);
             break;
         }
 #endif
 #ifdef MNN_SUPPORT_BF16
         if (precision == BackendConfig::Precision_Low_BF16 && BF16Functions::get()) {
-            res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU_EXTENSION, 0, initThreadNumber);
+            res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU_EXTENSION);
             res->mCoreFunctions = BF16Functions::get();
             break;
         }
 #endif
         if (flags == MNN_CPU_USE_DEFAULT_BACKEND) {
             // Default don't use multi-thread init
-            res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0, 0);
+            res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0);
             break;
         }
 #ifdef MNN_USE_SSE
@@ -297,7 +362,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
             break;
         }
 #endif
-        res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags, initThreadNumber);
+        res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags);
     } while (false);
     mSharedDmaInfo = nullptr;
     return res;
@@ -338,9 +403,9 @@ void CPURuntime::onGabageCollect(int level) {
 void CPURuntime::onConcurrencyBegin() const {
 #ifdef MNN_USE_THREAD_POOL
     if (mTaskIndex >= 0) {
-        if (mThreadOpen == 0) {
+        if (mThreadOpen == 0 && mThreadPool) {
             // mThreadOpen 0 -> 1, open ThreadPool
-            ThreadPool::active(mThreadNumber);
+            mThreadPool->active();
         }
         mThreadOpen++;
     }
@@ -359,8 +424,8 @@ void CPURuntime::onConcurrencyEnd() const {
         MNN_ASSERT(mThreadOpen > 0);
         mThreadOpen--;
         mThreadOpen = mThreadOpen < 0 ? 0 : mThreadOpen;
-        if (0 == mThreadOpen) {
-            ThreadPool::deactive(mThreadNumber);
+        if (0 == mThreadOpen && mThreadPool) {
+            mThreadPool->deactive();
         }
     }
 #endif
@@ -389,13 +454,16 @@ BufferAllocator* CPURuntime::createDynamicBufferAlloctor(int index) const {
     }
     return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()));
 }
-CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type, size_t flags, int initThreadNumber) : Backend(type) {
+CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type, size_t flags) : Backend(type) {
 #ifdef LOG_VERBOSE
     MNN_PRINT("cpu backend create\n");
 #endif
     mMemory = memory;
     mRuntime = const_cast<CPURuntime*>(runtime);
     mThreadNumber = mRuntime->mThreadNumber;
+#ifdef MNN_USE_THREAD_POOL
+    mThreadPool = mRuntime->mThreadPool;
+#endif
     // Compute Group Rate
     do {
         if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) {
@@ -452,13 +520,6 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
         mCacheGroup[i].reset(new CPUResizeCache);
     }
     mCache = mCacheGroup[0].get();
-#if 0
-#ifndef MNN_FORBIT_MULTI_THREADS
-    if (initThreadNumber > 0) {
-        mInitWorkQueue.reset(new WorkerThread(initThreadNumber));
-    }
-#endif
-#endif
 }
 
 CPUBackend::~CPUBackend() {
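
The new validation rule can be summarized in a short standalone sketch (illustrative code, not MNN itself): every requested ID must exist in MNNGetCPUInfo()'s groups and all IDs must come from the same group, otherwise the list is cleared and the Power_Low/Power_High fallback selection applies.

```cpp
#include <unordered_map>
#include <vector>

// Standalone illustration of the rule enforced by CPURuntime::_validateCpuIds:
// every requested core ID must be known, and all IDs must belong to one group.
bool sameClusterOnly(const std::vector<std::vector<int>>& groups,
                     const std::vector<int>& requested) {
    std::unordered_map<int, int> groupOf;  // core id -> group index
    for (int g = 0; g < (int)groups.size(); ++g) {
        for (int id : groups[g]) {
            groupOf[id] = g;
        }
    }
    if (requested.empty()) {
        return false;
    }
    auto first = groupOf.find(requested[0]);
    if (first == groupOf.end()) {
        return false;
    }
    for (int id : requested) {
        auto it = groupOf.find(id);
        if (it == groupOf.end() || it->second != first->second) {
            return false;
        }
    }
    return true;
}
```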

+ 15 - 13
source/backend/cpu/CPUBackend.hpp

@@ -17,6 +17,10 @@
 #include "core/BufferAllocator.hpp"
 #include "MNN_generated.h"
 
+#ifdef MNN_USE_THREAD_POOL
+#include "ThreadPool.hpp"
+#endif
+
 #ifdef MNN_KLEIDIAI_ENABLED
 #include "arm/mnn_kleidiai.h"
 #endif
@@ -45,22 +49,21 @@ public:
     virtual void onConcurrencyEnd() const override;
     virtual bool onCheckInfo(Backend::Info& info) const override;
 
-#ifdef MNN_USE_THREAD_POOL
-    inline bool multiThreadValid() const {
-        return mThreadOpen;
-    }
-#endif
     SingleBufferWithAllocator* buffer(int index) const;
     BufferAllocator* createDynamicBufferAlloctor(int index) const;
 
 private:
     void _bindCPUCore() const;
-    void _resetThreadPool();
+    void _resetThreadPool() const;
+    void _validateCpuIds() const;
     mutable std::shared_ptr<EagerBufferAllocator> mStaticAllocator;
-    int mThreadNumber;
+    mutable int mThreadNumber;
+    mutable std::vector<int> mCpuIds;
+    mutable unsigned long mCpuMask;
 #ifdef MNN_USE_THREAD_POOL
     mutable int mTaskIndex = -1;
     mutable int mThreadOpen = 0;
+    mutable ThreadPool* mThreadPool = nullptr;
 #endif
     BackendConfig::MemoryMode mMemory;
     BackendConfig::PowerMode mPower;
@@ -102,7 +105,7 @@ private:
 };
 class CPUBackend : public Backend {
 public:
-    CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type = MNN_FORWARD_CPU, size_t flags = 0, int initThreadNumber = 0);
+    CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type = MNN_FORWARD_CPU, size_t flags = 0);
     virtual ~CPUBackend();
 
     // Return sizeDivide, scheduleNumber aligned memory
@@ -149,11 +152,6 @@ public:
     inline int threadNumber() const {
         return mThreadNumber;
     }
-#ifdef MNN_USE_THREAD_POOL
-    inline bool threadOpen() const {
-        return mRuntime->mThreadOpen > 0;
-    }
-#endif
 
     BufferAllocator* getBufferAllocator(bool defer_allocator = true) const {
         return mDmaInfo->mCurrentDynamicAllocator;
@@ -173,6 +171,7 @@ public:
 
 #ifdef MNN_USE_THREAD_POOL
     inline int taskIndex() const {return mRuntime->mTaskIndex;}
+    inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;}
 #endif
     static void initCreatorMap();
     static size_t getBytes(const Backend* backend, const Tensor* output);
@@ -187,6 +186,9 @@ protected:
 private:
     mutable std::shared_ptr<WorkerThread> mInitWorkQueue;
     int mThreadNumber;
+#ifdef MNN_USE_THREAD_POOL
+    ThreadPool* mThreadPool = nullptr;
+#endif
     std::vector<std::pair<float, int>> mGroupWithComputeRate;
     float mComputeI = 0.f;
 

+ 64 - 11
source/backend/cpu/CPURuntime.cpp

@@ -82,22 +82,56 @@ int MNNGetCurrentPid() {
     return 0;
 #endif
 }
-int MNNSetSchedAffinity(const int* cpuIDs, int size) {
+
 #if defined (__linux__)
-#ifndef CPU_SETSIZE
+// Referenced from: (LINUX) bits/cpu-set.h
+// https://sourceware.org/git/?p=glibc.git;a=blob_plain;f=posix/bits/cpu-set.h;hb=HEAD
+// Copied from: (ANDROID) libc/include/sched.h
+// https://android.googlesource.com/platform/bionic.git/+/master/libc/include/sched.h
+#ifdef __LP64__
 #define CPU_SETSIZE 1024
+#else
+#define CPU_SETSIZE 32
+#endif
+#define __CPU_BITTYPE  unsigned long int  /* mandated by the kernel  */
+#define __CPU_BITS     (8 * sizeof(__CPU_BITTYPE))
+#define __CPU_ELT(x)   ((x) / __CPU_BITS)
+#define __CPU_MASK(x)  ((__CPU_BITTYPE)1 << ((x) & (__CPU_BITS - 1)))
+/**
+ * [CPU_ZERO](https://man7.org/linux/man-pages/man3/CPU_ZERO.3.html) clears all
+ * bits in a static CPU set.
+ */
+#define CPU_ZERO(set) CPU_ZERO_S(sizeof(cpu_set_t), set)
+/**
+ * [CPU_ZERO_S](https://man7.org/linux/man-pages/man3/CPU_ZERO_S.3.html) clears
+ * all bits in a dynamic CPU set allocated by `CPU_ALLOC`.
+ */
+#define CPU_ZERO_S(setsize, set) __builtin_memset(set, 0, setsize)
+/**
+ * [CPU_SET](https://man7.org/linux/man-pages/man3/CPU_SET.3.html) sets one
+ * bit in a static CPU set.
+ */
+#define CPU_SET(cpu, set) CPU_SET_S(cpu, sizeof(cpu_set_t), set)
+/**
+ * [CPU_SET_S](https://man7.org/linux/man-pages/man3/CPU_SET_S.3.html) sets one
+ * bit in a dynamic CPU set allocated by `CPU_ALLOC`.
+ */
+#define CPU_SET_S(cpu, setsize, set)                              \
+    do {                                                          \
+        size_t __cpu = (cpu);                                     \
+        if (__cpu < 8 * (setsize))                                \
+            (set)->__bits[__CPU_ELT(__cpu)] |= __CPU_MASK(__cpu); \
+    } while (0)
 #endif
-#define __NCPUBITS (8 * sizeof(unsigned long))
+int MNNSetSchedAffinity(const int* cpuIDs, int size) {
+#if defined (__linux__)
+    /**
+     * [cpu_set_t](https://man7.org/linux/man-pages/man3/CPU_SET.3.html) is a
+     * statically-sized CPU set. See `CPU_ALLOC` for dynamically-sized CPU sets.
+     */
     typedef struct {
-        unsigned long __bits[CPU_SETSIZE / __NCPUBITS];
+        __CPU_BITTYPE __bits[CPU_SETSIZE / __CPU_BITS];
     } cpu_set_t;
-
-#ifndef CPU_SET
-#define CPU_SET(cpu, cpusetp) ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
-#endif
-#ifndef CPU_ZERO
-#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t))
-#endif
     // set affinity for thread
     pid_t pid = MNNGetCurrentPid();
     cpu_set_t mask;
@@ -115,6 +149,25 @@ int MNNSetSchedAffinity(const int* cpuIDs, int size) {
     return 0;
 }
 
+cpu_mask_t MNNGetCPUMask(const std::vector<int>& cpuIds) {
+#if defined (__linux__)
+    /**
+     * [cpu_set_t](https://man7.org/linux/man-pages/man3/CPU_SET.3.html) is a
+     * statically-sized CPU set. See `CPU_ALLOC` for dynamically-sized CPU sets.
+     */
+    typedef struct {
+        __CPU_BITTYPE __bits[CPU_SETSIZE / __CPU_BITS];
+    } cpu_set_t;
+    cpu_set_t cpuMask;
+    CPU_ZERO(&cpuMask);
+    for (auto i :cpuIds){
+        CPU_SET(i, &cpuMask);
+    }
+    return cpuMask.__bits[0];
+#endif
+    return 0;
+}
+
 // cpuinfo
 // Reference from: https://github.com/pytorch/cpuinfo
 #if defined(ENABLE_ARMV82) && defined(__arm__)
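
As a quick illustration of the mask representation (values are illustrative): on Linux the mask is a plain bitset over core IDs, so {0, 2, 5} maps to (1ul << 0) | (1ul << 2) | (1ul << 5) == 0x25, while non-Linux builds get 0.

```cpp
#include <vector>
#include "CPURuntime.hpp"  // source/backend/cpu/CPURuntime.hpp in the MNN tree

// Sketch only: MNNGetCPUMask packs the requested core IDs into the low word
// of a cpu_set_t.
cpu_mask_t exampleMask() {
    std::vector<int> cpuIds = {0, 2, 5};
    return MNNGetCPUMask(cpuIds);  // expected 0x25 on Linux, 0 elsewhere
}
```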

+ 2 - 1
source/backend/cpu/CPURuntime.hpp

@@ -25,9 +25,10 @@ struct MNNCPUInfo {
     std::vector<CPUGroup> groups;
     int cpuNumber = 0;
 };
-
+using cpu_mask_t = unsigned long;
 int MNNSetSchedAffinity(const int* cpuIDs, int size);
 int MNNGetCurrentPid();
+cpu_mask_t MNNGetCPUMask(const std::vector<int>& cpuIds);
 const MNNCPUInfo* MNNGetCPUInfo();
 
 #endif /* CPUInfo_hpp */

+ 43 - 60
source/backend/cpu/ThreadPool.cpp

@@ -8,41 +8,43 @@
 #ifdef MNN_USE_THREAD_POOL
 #include "backend/cpu/ThreadPool.hpp"
 #include <string.h>
+#include <unordered_map>
 #include <MNN/MNNDefine.h>
+#include "ThreadPool.hpp"
 
 #define MNN_THREAD_POOL_MAX_TASKS 2
 namespace MNN {
-ThreadPool* ThreadPool::gInstance = nullptr;
+static std::unordered_map<long int, ThreadPool*> gInstances;
 static std::mutex gInitMutex;
-int ThreadPool::init(int number) {
-    if (1 >= number) {
-        return 1;
+int ThreadPool::init(int numberThread, unsigned long cpuMask, ThreadPool*& threadPool) {
+    if (1 >= numberThread) {
+        numberThread = 1;
     }
     std::lock_guard<std::mutex> _l(gInitMutex);
-    if (nullptr != gInstance) {
-        if (gInstance->number() < number) {
-            return gInstance->number();
-        }
+
+    if (gInstances.find(cpuMask) == gInstances.end()){
+        gInstances[cpuMask] = new ThreadPool(numberThread);
     }
-    if (nullptr == gInstance) {
-        gInstance = new ThreadPool(number);
+    threadPool = gInstances[cpuMask];
+    if (gInstances[cpuMask]->numberThread() < numberThread){
+        return gInstances[cpuMask]->numberThread();
     }
-    return number;
+    return numberThread;
 }
+
 void ThreadPool::destroy() {
     std::lock_guard<std::mutex> _l(gInitMutex);
-    if (nullptr != gInstance) {
-        delete gInstance;
-        gInstance = nullptr;
+    for (auto i= gInstances.begin(); i != gInstances.end(); i++){
+        if (i->second){
+            delete i->second;
+        }
     }
+    gInstances.clear();
 }
 
 ThreadPool::ThreadPool(int numberThread) {
     mNumberThread = numberThread;
-    mActiveCount.resize(numberThread);
-    for (int i=0; i<numberThread; ++i) {
-        mActiveCount[i] = new std::atomic_int(0);
-    }
+    mActiveCount  = 0;
     mTaskAvailable.resize(MNN_THREAD_POOL_MAX_TASKS);
     mTasks.resize(MNN_THREAD_POOL_MAX_TASKS);
     for (int t = 0; t < mTasks.size(); ++t) {
@@ -55,7 +57,7 @@ ThreadPool::ThreadPool(int numberThread) {
         int threadIndex = i;
         mWorkers.emplace_back([this, threadIndex]() {
             while (!mStop) {
-                while (*mActiveCount[threadIndex] > 0) {
+                while (mActiveCount > 0) {
                     for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) {
                         if (*mTasks[i].second[threadIndex]) {
                             mTasks[i].first.first(threadIndex);
@@ -65,7 +67,7 @@ ThreadPool::ThreadPool(int numberThread) {
                     std::this_thread::yield();
                 }
                 std::unique_lock<std::mutex> _l(mQueueMutex);
-                mCondition.wait(_l, [this, threadIndex] { return mStop || *mActiveCount[threadIndex] > 0; });
+                mCondition.wait(_l, [this] { return mStop || mActiveCount > 0; });
             }
         });
     }
@@ -85,82 +87,63 @@ ThreadPool::~ThreadPool() {
             delete c;
         }
     }
-    for (int i=0; i<mActiveCount.size(); ++i) {
-        delete mActiveCount[i];
-    }
 }
 
 int ThreadPool::acquireWorkIndex() {
-    if (nullptr == gInstance) {
-        return -1;
-    }
-    std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
+    std::lock_guard<std::mutex> _l(mQueueMutex);
     for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) {
-        if (gInstance->mTaskAvailable[i]) {
-            gInstance->mTaskAvailable[i] = false;
+        if (mTaskAvailable[i]) {
+            mTaskAvailable[i] = false;
             return i;
         }
     }
     return -1;
 }
 void ThreadPool::releaseWorkIndex(int index) {
-    if (nullptr == gInstance) {
-        return;
-    }
     if (index < 0 || index >= MNN_THREAD_POOL_MAX_TASKS) {
         return;
     }
-    std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
-    gInstance->mTaskAvailable[index] = true;
+    std::lock_guard<std::mutex> _l(mQueueMutex);
+    mTaskAvailable[index] = true;
 }
 
-void ThreadPool::active(int threadNumber) {
-    if (nullptr == gInstance) {
-        return;
-    }
+void ThreadPool::active() {
     {
-        std::lock_guard<std::mutex> _l(gInstance->mQueueMutex);
-        for (int i=0; i<threadNumber; ++i) {
-            (*gInstance->mActiveCount[i])++;
-        }
+        std::lock_guard<std::mutex> _l(mQueueMutex);
+        mActiveCount++;
     }
-    gInstance->mCondition.notify_all();
+    mCondition.notify_all();
 }
-void ThreadPool::deactive(int threadNumber) {
-    if (nullptr == gInstance) {
-        return;
-    }
-    for (int i=0; i<threadNumber; ++i) {
-        (*gInstance->mActiveCount[i])--;
-    }
+void ThreadPool::deactive() {
+    mActiveCount--;
 }
 
-void ThreadPool::enqueue(TASK&& task, int index, int threadNumber) {
+void ThreadPool::enqueue(TASK&& task, int index) {
     if (1 >= task.second || 0 > index) {
         for (int i = 0; i < task.second; ++i) {
             task.first(i);
         }
         return;
     }
-    MNN_ASSERT(nullptr != gInstance);
-    gInstance->enqueueInternal(std::move(task), index, threadNumber);
+    enqueueInternal(std::move(task), index);
 }
-void ThreadPool::enqueueInternal(TASK&& task, int index, int threadNumber) {
-    if (threadNumber <= 1) {
+void ThreadPool::enqueueInternal(TASK&& task, int index) {
+    if (mActiveCount == 0) {
         for (int i = 0; i < task.second; ++i) {
             task.first(i);
         }
         return;
     }
     int workSize = task.second;
-    if (workSize > threadNumber) {
+    if (workSize > mNumberThread) {
         mTasks[index].first = std::make_pair(
-            [workSize, &task, threadNumber, this](int tId) {
-                for (int v = tId; v < workSize; v += threadNumber) {
+            [workSize, &task, this](int tId) {
+                for (int v = tId; v < workSize; v += mNumberThread) {
                     task.first(v);
                 }
-            },threadNumber);
-        workSize = threadNumber;
+            },
+            mNumberThread);
+        workSize = mNumberThread;
     } else {
         mTasks[index].first = std::move(task);
     }

+ 10 - 11
source/backend/cpu/ThreadPool.hpp

@@ -22,25 +22,24 @@ class MNN_PUBLIC ThreadPool {
 public:
     typedef std::pair<std::function<void(int)>, int> TASK;
 
-    int number() const {
+    int numberThread() const {
         return mNumberThread;
     }
-    static void enqueue(TASK&& task, int index, int threadNumber);
+    void enqueue(TASK&& task, int index);
 
-    static void active(int threadNumber);
-    static void deactive(int threadNumber);
+    void active();
+    void deactive();
 
-    static int acquireWorkIndex();
-    static void releaseWorkIndex(int index);
+    int acquireWorkIndex();
+    void releaseWorkIndex(int index);
 
-    static int init(int number);
+    static int init(int numberThread, unsigned long cpuMask, ThreadPool*& threadPool);
     static void destroy();
 
 private:
-    void enqueueInternal(TASK&& task, int index, int threadNumber);
+    void enqueueInternal(TASK&& task, int index);
 
-    static ThreadPool* gInstance;
-    ThreadPool(int number = 0);
+    ThreadPool(int numberThread = 0);
     ~ThreadPool();
 
     std::vector<std::thread> mWorkers;
@@ -52,7 +51,7 @@ private:
     std::mutex mQueueMutex;
 
     int mNumberThread            = 0;
-    std::vector<std::atomic_int*> mActiveCount;
+    std::atomic_int mActiveCount = {0};
 };
 } // namespace MNN
 #endif
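
Putting the reworked pool API together, a hedged usage sketch that mirrors the updated ThreadPoolTest (only meaningful when MNN_USE_THREAD_POOL is defined; the thread count and mask are illustrative):

```cpp
#include <utility>
#include "backend/cpu/ThreadPool.hpp"

void runOnPool() {
    MNN::ThreadPool* pool = nullptr;
    // Pools are now cached per CPU mask; 0 means "no affinity preference".
    int threads = MNN::ThreadPool::init(4, /*cpuMask=*/0, pool);
    int workIndex = pool->acquireWorkIndex();
    if (workIndex < 0) {
        return;  // all MNN_THREAD_POOL_MAX_TASKS task slots are in use
    }
    pool->active();
    pool->enqueue(std::make_pair([](int /*tId*/) { /* per-slice work */ }, threads),
                  workIndex);
    pool->deactive();
    pool->releaseWorkIndex(workIndex);
}
```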

+ 3 - 0
source/core/Backend.hpp

@@ -58,6 +58,9 @@ struct RuntimeHint {
     // op encoder number for once commit
     int encorderNumForCommit = 10;
     int initThreadNumber = 0;
+
+    // cpu core ids
+    std::vector<int> cpuIds;
 };
 /** abstract backend */
 class Backend : public NonCopyable {

+ 2 - 1
source/core/Concurrency.h

@@ -28,7 +28,8 @@
     }                                                              \
     ;                                                              \
     auto cpuBn = (CPUBackend*)backend();                           \
-    MNN::ThreadPool::enqueue(std::move(task), cpuBn->taskIndex(), cpuBn->threadOpen() ? cpuBn->threadNumber() : 1); \
+    auto thrPl = cpuBn->threadPool();                              \
+    thrPl->enqueue(std::move(task), cpuBn->taskIndex());           \
     }
 
 #else

+ 16 - 12
source/core/Interpreter.cpp

@@ -140,22 +140,26 @@ Interpreter* Interpreter::createFromBufferInternal(Content* net, bool enforceAut
     return new Interpreter(net);
 }
 
-void Interpreter::setSessionHint(HintMode mode, int hint) {
-    mNet->modes.setHint(mode, hint);
+void Interpreter::setSessionHint(HintMode hint, int value) {
+    mNet->modes.setHint(hint, value);
+}
+
+void Interpreter::setSessionHint(HintMode hint, int* value, size_t size) {
+    mNet->modes.setHint(hint, value, size);
 }
 
 void Interpreter::setSessionMode(SessionMode mode) {
-    if (mode == Session_Resize_Check) {
-        for (auto& iter : mNet->sessions) {
-            iter->openResizeCheck();
-        }
-    } else if (mode == Session_Resize_Fix) {
-        for (auto& iter : mNet->sessions) {
-            iter->fixResizeCache();
-        }
-    } else {
-        mNet->modes.setMode(mode);
+  if (mode == Session_Resize_Check) {
+    for (auto& iter : mNet->sessions) {
+      iter->openResizeCheck();
+    }
+  } else if (mode == Session_Resize_Fix) {
+    for (auto& iter : mNet->sessions) {
+      iter->fixResizeCache();
     }
+  } else {
+    mNet->modes.setMode(mode);
+  }
 }
 
 void Interpreter::setCacheFile(const char* cacheFile, size_t keySize) {

+ 37 - 28
source/core/Session.cpp

@@ -68,46 +68,55 @@ void Session::ModeGroup::setMode(Interpreter::SessionMode mode) {
         codegenMode = mode;
     }
 }
-void Session::ModeGroup::setHint(Interpreter::HintMode mode, int hint) {
-    switch (mode) {
-        case Interpreter::MAX_TUNING_NUMBER:
-            maxTuningNumber = hint;
+void Session::ModeGroup::setHint(Interpreter::HintMode hint, int value) {
+    switch (hint) {
+        case Interpreter::HintMode::MAX_TUNING_NUMBER:
+            maxTuningNumber = value;
             break;
-        case Interpreter::MEM_ALLOCATOR_TYPE:
-            runtimeHint.memoryAllocatorType = hint;
+        case Interpreter::HintMode::MEM_ALLOCATOR_TYPE:
+            runtimeHint.memoryAllocatorType = value;
             break;
-        case Interpreter::WINOGRAD_MEMORY_LEVEL:
-            runtimeHint.winogradMemoryUsed = hint;
+        case Interpreter::HintMode::WINOGRAD_MEMORY_LEVEL:
+            runtimeHint.winogradMemoryUsed = value;
             break;
-        case Interpreter::CPU_LITTLECORE_DECREASE_RATE:
-            runtimeHint.cpuDecreaseRate = hint;
+        case Interpreter::HintMode::CPU_LITTLECORE_DECREASE_RATE:
+            runtimeHint.cpuDecreaseRate = value;
             break;
-        case Interpreter::GEOMETRY_COMPUTE_MASK:
-            geometryMask = hint;
+        case Interpreter::HintMode::GEOMETRY_COMPUTE_MASK:
+            geometryMask = value;
             break;
-        case Interpreter::STRICT_CHECK_MODEL:
-            checkNetBuffer = hint > 0;
+        case Interpreter::HintMode::STRICT_CHECK_MODEL:
+            checkNetBuffer = value > 0;
             break;
-        case Interpreter::DYNAMIC_QUANT_OPTIONS:
-            runtimeHint.dynamicQuantOption = hint;
+        case Interpreter::HintMode::DYNAMIC_QUANT_OPTIONS:
+            runtimeHint.dynamicQuantOption = value;
             break;
-        case Interpreter::QKV_QUANT_OPTIONS:
-            runtimeHint.qkvQuantOption = hint;
+        case Interpreter::HintMode::QKV_QUANT_OPTIONS:
+            runtimeHint.qkvQuantOption = value;
             break;
-        case Interpreter::KVCACHE_SIZE_LIMIT:
-            runtimeHint.kvcacheSizeLimit = hint;
+        case Interpreter::HintMode::KVCACHE_SIZE_LIMIT:
+            runtimeHint.kvcacheSizeLimit = value;
             break;
-        case Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT:
-            runtimeHint.encorderNumForCommit = hint;
+        case Interpreter::HintMode::OP_ENCODER_NUMBER_FOR_COMMIT:
+            runtimeHint.encorderNumForCommit = value;
             break;
-        case Interpreter::MMAP_FILE_SIZE:
-            runtimeHint.mmapFileSize = hint;
+        case Interpreter::HintMode::MMAP_FILE_SIZE:
+            runtimeHint.mmapFileSize = value;
             break;
-        case Interpreter::USE_CACHED_MMAP:
-            runtimeHint.useCachedMmap = hint;
+        case Interpreter::HintMode::USE_CACHED_MMAP:
+            runtimeHint.useCachedMmap = value;
             break;
-        case Interpreter::INIT_THREAD_NUMBER:
-            runtimeHint.initThreadNumber = hint;
+        case Interpreter::HintMode::INIT_THREAD_NUMBER:
+            runtimeHint.initThreadNumber = value;
+            break;
+        default:
+            break;
+    }
+}
+void Session::ModeGroup::setHint(Interpreter::HintMode hint, int* value, size_t size) {
+    switch (hint) {
+        case Interpreter::HintMode::CPU_CORE_IDS:
+            runtimeHint.cpuIds = std::vector<int>(value, value + size);
             break;
         default:
             break;

+ 3 - 1
source/core/Session.hpp

@@ -37,7 +37,9 @@ public:
         int geometryMask = 0xFFFF;
         bool checkNetBuffer = true;
         RuntimeHint runtimeHint;
-        void setHint(Interpreter::HintMode hint, int magic);
+        void setHint(Interpreter::HintMode hint, int value);
+        void setHint(Interpreter::HintMode hint, int* value, size_t size);
+        void setHintPtr(Interpreter::HintMode hint, int value);
         void setMode(Interpreter::SessionMode mode);
         void setExternalPath(std::string path, int type);
     };

+ 7 - 6
test/core/ThreadPoolTest.cpp

@@ -20,18 +20,19 @@ public:
         std::vector<std::thread> threads;
         for (int i = 0; i < 10; ++i) {
             threads.emplace_back([i]() {
-                int number = MNN::ThreadPool::init(10 - i);
+                MNN::ThreadPool* threadPool = nullptr;
+                MNN::ThreadPool::init(10 - i, 0, threadPool);
                 // initializer
-                auto workIndex = ThreadPool::acquireWorkIndex();
+                auto workIndex = threadPool->acquireWorkIndex();
                 FUNC_PRINT(workIndex);
-                ThreadPool::active(number);
+                threadPool->active();
                 auto func = [](int index) {
                     FUNC_PRINT(index);
                     std::this_thread::yield();
                 };
-                ThreadPool::enqueue(std::make_pair(std::move(func), 10), workIndex, number);
-                ThreadPool::deactive(number);
-                ThreadPool::releaseWorkIndex(workIndex);
+                threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex);
+                threadPool->deactive();
+                threadPool->releaseWorkIndex(workIndex);
             });
         }
         for (auto& t : threads) {

+ 34 - 16
tools/cpp/MNNV2Basic.cpp

@@ -167,11 +167,28 @@ static inline int64_t getTimeInUs() {
     return time;
 }
 
+static inline std::vector<int> parseIntList(const std::string& str, char delim) {
+    std::vector<int> result;
+    std::ptrdiff_t p1 = 0, p2;
+    while (1) {
+        p2 = str.find(delim, p1);
+        if (p2 != std::string::npos) {
+            result.push_back(atoi(str.substr(p1, p2 - p1).c_str()));
+            p1 = p2 + 1;
+        } else {
+            result.push_back(atoi(str.substr(p1).c_str()));
+            break;
+        }
+    }
+    return result;
+}
+
 static int test_main(int argc, const char* argv[]) {
     if (argc < 2) {
-        MNN_PRINT("========================================================================\n");
-        MNN_PRINT("Arguments: model.MNN runLoops runMask forwardType numberThread precision inputSize \n");
-        MNN_PRINT("========================================================================\n");
+        MNN_PRINT("=========================================================================================\n");
+        MNN_PRINT("Arguments: model.MNN runLoops runMask forwardType numberThread precision inputSize cpuIds\n");
+        MNN_PRINT("Example: %s model.MNN 100 0 0 4 0 1x3x224x224 0,1,2,3\n", argv[0]);
+        MNN_PRINT("=========================================================================================\n");
         return -1;
     }
 
@@ -227,25 +244,25 @@ static int test_main(int argc, const char* argv[]) {
     // input dims
     std::vector<int> inputDims;
     if (argc > 7) {
-        std::string inputShape(argv[7]);
-        const char* delim = "x";
-        std::ptrdiff_t p1 = 0, p2;
-        while (1) {
-            p2 = inputShape.find(delim, p1);
-            if (p2 != std::string::npos) {
-                inputDims.push_back(atoi(inputShape.substr(p1, p2 - p1).c_str()));
-                p1 = p2 + 1;
-            } else {
-                inputDims.push_back(atoi(inputShape.substr(p1).c_str()));
-                break;
-            }
-        }
+        inputDims = parseIntList(argv[7], 'x');
     }
+    MNN_PRINT("inputDims: ");
     for (auto dim : inputDims) {
         MNN_PRINT("%d ", dim);
     }
     MNN_PRINT("\n");
 
+    // CPU IDs
+    std::vector<int> cpuIds;
+    if (argc > 8) {
+        cpuIds = parseIntList(argv[8], ',');
+    }
+    MNN_PRINT("cpuIds: ");
+    for (auto id : cpuIds) {
+        MNN_PRINT("%d ", id);
+    }
+    MNN_PRINT("\n");
+
     // create net
     MNN_PRINT("Open Model %s\n", fileName);
     std::shared_ptr<MNN::Interpreter> net =
@@ -265,6 +282,7 @@ static int test_main(int argc, const char* argv[]) {
     if (runMask & 32) {
         net->setSessionHint(Interpreter::WINOGRAD_MEMORY_LEVEL, 0);
     }
+    net->setSessionHint(Interpreter::HintMode::CPU_CORE_IDS, cpuIds.data(), cpuIds.size());
 
     // create session
     MNN::ScheduleConfig config;

+ 30 - 1
tools/cpp/ModuleBasic.cpp

@@ -90,9 +90,27 @@ static bool compareOutput(VARP output, const std::string& directName, const std:
     }
     return true;
 }
+
+static inline std::vector<int> parseIntList(const std::string& str, char delim) {
+    std::vector<int> result;
+    std::ptrdiff_t p1 = 0, p2;
+    while (1) {
+        p2 = str.find(delim, p1);
+        if (p2 != std::string::npos) {
+            result.push_back(atoi(str.substr(p1, p2 - p1).c_str()));
+            p1 = p2 + 1;
+        } else {
+            result.push_back(atoi(str.substr(p1).c_str()));
+            break;
+        }
+    }
+    return result;
+}
 int main(int argc, char *argv[]) {
     if (argc < 3) {
-        MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile]\n");
+        MNN_PRINT("=======================================================================================================================================\n");
+        MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds]\n");
+        MNN_PRINT("=======================================================================================================================================\n");
         return 0;
     }
     BackendConfig backendConfigTmp;
@@ -220,6 +238,16 @@ int main(int argc, char *argv[]) {
     if (argc > 8) {
         cacheFileName = argv[8];
     }
+    // CPU IDs
+    std::vector<int> cpuIds;
+    if (argc > 9) {
+        cpuIds = parseIntList(argv[9], ',');
+    }
+    MNN_PRINT("cpuIds: ");
+    for (auto id : cpuIds) {
+        MNN_PRINT("%d ", id);
+    }
+    MNN_PRINT("\n");
     FUNC_PRINT(precision);
     FUNC_PRINT(memory);
     FUNC_PRINT(power);
@@ -246,6 +274,7 @@ int main(int argc, char *argv[]) {
     std::shared_ptr<Executor::RuntimeManager> rtmgr(Executor::RuntimeManager::createRuntimeManager(config));
     rtmgr->setCache(cacheFileName);
     rtmgr->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4);
+    rtmgr->setHint(MNN::Interpreter::HintMode::CPU_CORE_IDS, cpuIds.data(), cpuIds.size());
 
     if (cpuDecreaseRate > 0 && cpuDecreaseRate <= 100) {
         rtmgr->setHint(Interpreter::CPU_LITTLECORE_DECREASE_RATE, cpuDecreaseRate);

+ 38 - 14
tools/cpp/timeProfile.cpp

@@ -23,7 +23,30 @@
 
 using namespace MNN;
 
+static inline std::vector<int> parseIntList(const std::string& str, char delim) {
+    std::vector<int> result;
+    std::ptrdiff_t p1 = 0, p2;
+    while (1) {
+        p2 = str.find(delim, p1);
+        if (p2 != std::string::npos) {
+            result.push_back(atoi(str.substr(p1, p2 - p1).c_str()));
+            p1 = p2 + 1;
+        } else {
+            result.push_back(atoi(str.substr(p1).c_str()));
+            break;
+        }
+    }
+    return result;
+}
 int main(int argc, const char* argv[]) {
+    if (argc < 2) {
+        MNN_PRINT("=========================================================================================\n");
+        MNN_PRINT("Arguments: model.MNN runLoops forwardType inputSize numberThread precision sparsity cpuIds\n");
+        MNN_PRINT("Example: %s model.MNN 100 0 1x3x224x224 4 0 0 0,1,2,3\n", argv[0]);
+        MNN_PRINT("=========================================================================================\n");
+        return -1;
+    }
+
     std::string cmd = argv[0];
     std::string pwd = "./";
     auto rslash     = cmd.rfind("/");
@@ -46,20 +69,9 @@ int main(int argc, const char* argv[]) {
     // input dims
     std::vector<int> inputDims;
     if (argc > 4) {
-        std::string inputShape(argv[4]);
-        const char* delim = "x";
-        std::ptrdiff_t p1 = 0, p2;
-        while (1) {
-            p2 = inputShape.find(delim, p1);
-            if (p2 != std::string::npos) {
-                inputDims.push_back(atoi(inputShape.substr(p1, p2 - p1).c_str()));
-                p1 = p2 + 1;
-            } else {
-                inputDims.push_back(atoi(inputShape.substr(p1).c_str()));
-                break;
-            }
-        }
+        inputDims = parseIntList(argv[4], 'x');
     }
+    MNN_PRINT("inputDims: ");
     for (auto dim : inputDims) {
         MNN_PRINT("%d ", dim);
     }
@@ -77,9 +89,20 @@ int main(int argc, const char* argv[]) {
     }
 
     float sparsity = 0.0f;
-    if(argc >= 8) {
+    if(argc > 7) {
         sparsity = atof(argv[7]);
     }
+    
+    // CPU IDs
+    std::vector<int> cpuIds;
+    if (argc > 8) {
+        cpuIds = parseIntList(argv[8], ',');
+    }
+    MNN_PRINT("cpuIds: ");
+    for (auto id : cpuIds) {
+        MNN_PRINT("%d ", id);
+    }
+    MNN_PRINT("\n");
 
 
     // revert MNN model if necessary
@@ -96,6 +119,7 @@ int main(int argc, const char* argv[]) {
     }
     revertor.reset();
     net->setSessionMode(Interpreter::Session_Debug);
+    net->setSessionHint(Interpreter::HintMode::CPU_CORE_IDS, cpuIds.data(), cpuIds.size());
 
     // create session
     MNN::ScheduleConfig config;