Expr.cpp

//
//  Expr.cpp
//  MNN
//
//  Created by MNN on 2019/06/10.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#define FLATBUFFERS_PREFER_PRINTF
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "Utils.hpp"
#include "core/FileLoader.hpp"
#include "core/TensorUtils.hpp"
#include "core/WrapExecution.hpp"
#include "MNN_generated.h"
//#define MNN_OPEN_TIME_TRACE
#include "MNN/AutoTime.hpp"
#include "MNN/expr/ExecutorScope.hpp"
#include "half.hpp"

//#define MNN_EXPRESS_ERROR_REPORT

static inline std::string numberToString(int index) {
    char s[10];
    snprintf(s, 10, "%d", index);
    return std::string(s);
}
static bool HasUnknownDim(const std::vector<int>& dims) {
    for (const int& dim : dims) {
        if (dim < 0) {
            return true;
        }
    }
    return false;
}

namespace MNN {
namespace Express {
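// syncSize() recomputes the element count from the dims. Any non-positive
// dim marks the shape as invalid (size = 0). For NC4HW4 layout the channel
// dim is padded up to a multiple of 4, matching how the data is packed.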
void Variable::Info::syncSize() {
    size = 1;
    for (int i = 0; i < dim.size(); ++i) {
        if (dim[i] <= 0) {
            // Not valid
            size = 0;
            return;
        }
        if (order == NC4HW4 && i == 1) {
            size *= (UP_DIV(dim[1], 4) * 4);
        } else {
            size *= dim[i];
        }
    }
}
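// fix() detaches a variable from the expression that produced it by
// materializing its current content as a standalone INPUT / CONSTANT /
// TRAINABLE expr and splicing it into the graph via Variable::replace.
// A minimal usage sketch (hypothetical variables):
//   VARP y = x + x;
//   y.fix(VARP::CONSTANT); // y now holds a constant snapshot of its value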
bool VARP::fix(VARP::InputType type) const {
    if (nullptr == mContent->expr().first->get()) {
        mContent->expr().first->mType = type;
        return true;
    }
    auto info = mContent->getInfo();
    if (nullptr == info) {
        return false;
    }
    VARP newVar;
    switch (type) {
        case INPUT: {
            newVar = _Input(info->dim, info->order, info->type);
            auto ptr = mContent->readMap<void>();
            if (nullptr != ptr) {
                auto dstPtr = newVar->writeMap<void>();
                ::memcpy(dstPtr, ptr, info->size * info->type.bytes());
            }
            break;
        }
        case CONSTANT: {
            auto ptr = mContent->readMap<void>();
            if (nullptr == ptr) {
                return false;
            }
            newVar = _Const(ptr, info->dim, info->order, info->type);
            break;
        }
        case TRAINABLE: {
            auto ptr = mContent->readMap<void>();
            if (nullptr == ptr) {
                return false;
            }
            newVar = _TrainableParam(ptr, info->dim, info->order, info->type);
            break;
        }
        default:
            return false;
    }
    Variable::replace(VARP(mContent), newVar);
    return true;
}
Expr::Expr(int outputSize) {
    mInside.reset(new Inside(outputSize));
    mOutputNames.resize(outputSize);
}
Expr::Expr(Tensor* tensor, bool own) {
    mInside.reset(new Inside(tensor, own));
    mOutputNames.resize(1);
}
Expr::~Expr() {
    mInside.reset();
}
Variable::Info* Expr::outputInfo(int index) const {
    return mInside->mOutputInfos.data() + index;
}
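// Register `expr` as a consumer of each of its inputs, reusing an expired
// weak-pointer slot in the producer's mTo list when one is available.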
void Expr::_addLinkForInputs(EXPRP expr) {
    auto inputs = expr->inputs();
    for (int i = 0; i < inputs.size(); ++i) {
        bool findEmpty = false;
        auto inputExpr = inputs[i]->mFrom;
        for (int j = 0; j < inputExpr->mTo.size(); ++j) {
            auto ref = inputExpr->mTo[j].lock();
            if (nullptr == ref) {
                inputExpr->mTo[j] = WeakEXPRP(expr);
                findEmpty = true;
                break;
            }
        }
        if (!findEmpty) {
            inputExpr->mTo.emplace_back(WeakEXPRP(expr));
        }
    }
}
EXPRP Expr::create(Tensor* tensor, bool own) {
    EXPRP expr(new Expr(tensor, own));
    expr->mOp = nullptr;
    expr->mType = VARP::CONSTANT;
    auto& dstInfo = expr->mInside->mOutputInfos[0];
    expr->mInside->mInfoDirty = false;
    expr->mInside->mContentDirty = false;
    return expr;
}
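// Create a leaf expr (no op) directly from an Info and an optional data
// pointer. The MemoryType controls ownership of `ptr`: COPY duplicates the
// data into a newly allocated host tensor, MOVE hands ownership of the
// allocation to the tensor, and REF merely aliases it (marked
// MEMORY_OUTSIDE so the tensor never frees it).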
EXPRP Expr::create(Variable::Info&& info, const void* ptr, VARP::InputType type, Expr::MemoryType memtype) {
    EXPRP expr(new Expr(1));
    expr->mOp = nullptr;
    auto originPtr = ptr;
    expr->mInside->mOutputInfos[0] = std::move(info);
    auto& dstInfo = expr->mInside->mOutputInfos[0];
    expr->mInside->mInfoDirty = false;
    dstInfo.syncSize();
    Utils::copyInfoToTensor(expr->mInside->mOutputTensors[0], expr->mInside->mOutputInfos.data());
    expr->mType = type;
    if (type == VARP::CONSTANT) {
        TensorUtils::getDescribe(expr->mInside->mOutputTensors[0])->usage = Tensor::InsideDescribe::CONSTANT;
    } else if (type == VARP::INPUT) {
        TensorUtils::getDescribe(expr->mInside->mOutputTensors[0])->usage = Tensor::InsideDescribe::INPUT;
    } else {
        // VARP::TRAINABLE
        TensorUtils::getDescribe(expr->mInside->mOutputTensors[0])->usage = Tensor::InsideDescribe::TRAINABLE;
    }
    if (dstInfo.size > 0 && memtype == COPY) {
        auto res = Utils::allocMemoryForHostTensor(expr->mInside->mOutputTensors[0]);
        if (!res) {
            MNN_ASSERT(false);
            return nullptr;
        }
    } else {
        expr->mInside->mOutputTensors[0]->buffer().host = nullptr;
    }
    if (nullptr == originPtr) {
        if (type == VARP::INPUT && dstInfo.size > 0) {
            expr->mInside->mContentDirty = true;
        }
        return expr;
    }
    expr->mInside->mContentDirty = false;
    if (memtype == COPY) {
        ::memcpy(expr->mInside->mOutputTensors[0]->buffer().host, originPtr, dstInfo.size * dstInfo.type.bytes());
    } else {
        expr->mInside->mOutputTensors[0]->buffer().host = (uint8_t*)originPtr;
        if (memtype == REF) {
            TensorUtils::getDescribe(expr->mInside->mOutputTensors[0])->memoryType = Tensor::InsideDescribe::MEMORY_OUTSIDE;
        }
    }
    return expr;
}
EXPRP Expr::create(std::shared_ptr<BufferStorage> extra, std::vector<VARP>&& inputs, int outputSize) {
    EXPRP expr(new Expr(outputSize));
    expr->mStorage = extra;
    expr->mOp = flatbuffers::GetRoot<Op>(extra->buffer());
    expr->mInputs = std::move(inputs);
    expr->mInside->mReq = ExecutorScope::Current()->getRequirement(expr.get());
    _addLinkForInputs(expr);
    return expr;
}
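// Build an expr from an unpacked OpT. Input / Const / TrainableParam ops
// become leaf exprs directly; half-float blobs are widened to float and
// handed over with MOVE semantics. All other ops are re-packed into a
// flatbuffer and wrapped via the BufferStorage overload above.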
EXPRP Expr::create(const OpT* op, std::vector<VARP> inputs, int outputSize) {
    if (OpType_Input == op->type) {
        Variable::Info info;
        info.dim = op->main.AsInput()->dims;
        if (info.dim.size() >= 1 && -1 == info.dim[0]) {
            info.dim[0] = 1;
        }
        info.order = Utils::revertFormat(op->main.AsInput()->dformat);
        info.type = Utils::revertDataType(op->main.AsInput()->dtype);
        return create(std::move(info), nullptr, VARP::INPUT);
    }
    if (OpType_Const == op->type || OpType_TrainableParam == op->type) {
        Variable::Info info;
        info.dim = op->main.AsBlob()->dims;
        info.order = Utils::revertFormat(op->main.AsBlob()->dataFormat);
        void* ptr = nullptr;
        info.type = Utils::revertDataType(op->main.AsBlob()->dataType);
        info.syncSize();
        switch (op->main.AsBlob()->dataType) {
            case DataType_DT_INT8:
                ptr = (void*)op->main.AsBlob()->int8s.data();
                break;
            case DataType_DT_INT32:
                ptr = (void*)op->main.AsBlob()->int32s.data();
                break;
            case DataType_DT_UINT8:
                ptr = (void*)op->main.AsBlob()->uint8s.data();
                break;
            case DataType_DT_FLOAT:
                ptr = (void*)op->main.AsBlob()->float32s.data();
                break;
            default:
                break;
        }
        Expr::MemoryType memtype = Expr::MemoryType::COPY;
        if (op->main.AsBlob()->dataType == DataType_DT_HALF) {
            auto src = (half_float::half*)op->main.AsBlob()->uint8s.data();
            ptr = MNNMemoryAllocAlign(info.size * sizeof(float), MNN_MEMORY_ALIGN_DEFAULT);
            if (nullptr == src || nullptr == ptr) {
                EXPRP empty;
                return empty;
            }
            auto outputPtr = (float*)ptr;
            for (int i = 0; i < info.size; ++i) {
                outputPtr[i] = src[i];
            }
            memtype = Expr::MemoryType::MOVE;
        }
        //MNN_ASSERT(nullptr != ptr);
        auto expr = create(std::move(info), ptr, VARP::CONSTANT, memtype);
        if (OpType_TrainableParam == op->type && nullptr != ptr) {
            expr->mType = VARP::TRAINABLE;
        }
        return expr;
    }
    flatbuffers::FlatBufferBuilder builder;
    auto offset = Op::Pack(builder, op);
    builder.Finish(offset);
    std::shared_ptr<BufferStorage> extra(new BufferStorage);
    extra->storage.reset(builder.ReleaseRaw(extra->allocated_size, extra->offset));
    auto resExpr = Expr::create(extra, std::move(inputs), outputSize);
    resExpr->setName(op->name);
    return resExpr;
}
void Expr::setName(const std::string& name) {
    mName = name;
}
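// Lazily compute this expr's output shapes. Returns early when the cached
// info is still valid; otherwise it checks that every input has info, pulls
// the content of inputs whose shape computation needs it, and then asks the
// current Executor to run shape inference.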
bool Expr::requireInfo() {
    if (!mInside->mInfoDirty) {
        return true;
    }
    if (!mValid) {
        return false;
    }
    if (nullptr == mOp) {
        return !HasUnknownDim(mInside->mOutputInfos[0].dim);
    }
    bool ready = true;
    for (int i = 0; i < mInputs.size(); ++i) {
        if (nullptr == mInputs[i] || nullptr == mInputs[i]->mFrom) {
            // The Variable was set to nullptr via the API
            return false;
        }
        auto inputInfo = mInputs[i]->getInfo();
        if (nullptr == inputInfo) {
#ifdef MNN_EXPRESS_ERROR_REPORT
            MNN_ERROR("%s, %d input not ready\n", mName.c_str(), i);
#endif
            mValid = false;
            return false;
        }
    }
    for (int i = 0; i < mInputs.size(); ++i) {
        auto& v = mInputs[i];
        if (mInside->mReq.shapeNeedContent[i]) {
            // If shape inference needs this input's content, the content must not be nullptr
            auto ptr = v->readInternal(true);
            if (nullptr == ptr) {
                ready = false;
                break;
            }
        }
    }
    if (!ready) {
        return false;
    }
    //MNN_PRINT("Info %s, %p Start\n", mName.c_str(), this);
    auto res = ExecutorScope::Current()->computeInfo(this);
    //MNN_PRINT("Info Compute %s\n", mName.c_str());
    if (NO_ERROR == res) {
        mInside->mInfoDirty = false;
    } else {
        mValid = false;
    }
    return NO_ERROR == res;
}
size_t Variable::linkNumber() const {
    return mFrom->outputs().size();
}
const std::vector<WeakEXPRP>& Variable::toExprs() const {
    return mFrom->outputs();
}
VARP Variable::create(EXPRP expr, int index) {
    VARP res(new Variable(expr, index));
#ifdef MNN_EXPR_SHAPE_EAGER
    auto info = expr->requireInfo();
    if (!info) {
#ifdef MNN_EXPRESS_ERROR_REPORT
        MNN_ERROR("Can't compute shape\n");
#endif
    }
#endif
    return res;
}
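// Rewrite `old` in place so it behaves like `from`: producer back-links are
// moved over, all members are copied, and every downstream expr has its
// cache dropped and shape info marked dirty (the visited flags guard
// against revisiting nodes during the traversal).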
void Expr::replace(EXPRP old, EXPRP from) {
    if (old.get() == from.get()) {
        return;
    }
    for (auto input : old->inputs()) {
        for (int j = 0; j < input->mFrom->mTo.size(); ++j) {
            auto ref = input->mFrom->mTo[j].lock();
            if (ref.get() == old.get()) {
                input->mFrom->mTo[j].reset();
            }
        }
    }
    for (auto input : from->inputs()) {
        bool hasSet = false;
        for (int j = 0; j < input->mFrom->mTo.size(); ++j) {
            auto ref = input->mFrom->mTo[j].lock();
            if (ref.get() == old.get()) {
                hasSet = true;
                break;
            }
        }
        if (!hasSet) {
            for (int j = 0; j < input->mFrom->mTo.size(); ++j) {
                auto ref = input->mFrom->mTo[j].lock();
                if (nullptr == ref) {
                    input->mFrom->mTo[j] = WeakEXPRP(old);
                    hasSet = true;
                    break;
                }
            }
        }
        if (!hasSet) {
            input->mFrom->mTo.emplace_back(WeakEXPRP(old));
        }
    }
    old->mOp = from->mOp;
    old->mName = from->mName;
    old->mOutputNames = from->mOutputNames;
    old->mStorage = from->mStorage;
    old->mType = from->mType;
    old->mValid = from->mValid;
    old->mInside = from->mInside;
    old->mInputs = from->mInputs;
    std::vector<Expr*> visited;
    old->visitOutputs([&](EXPRP expr, int index) {
        if (expr->visited()) {
            return false;
        }
        visited.emplace_back(expr.get());
        expr->setVisited(true);
        expr->mInside->mCache.reset();
        expr->mInside->mCacheOffset = 0;
        expr->mValid = true;
        expr->mInside->mInfoDirty = true;
        return true;
    });
    for (auto e : visited) {
        e->setVisited(false);
    }
}
void Variable::setName(const std::string& name) {
    mFrom->mOutputNames[mFromIndex] = name;
    if (mFrom->name().empty()) {
        mFrom->setName(name);
    }
}
const std::string& Variable::name() const {
    return mFrom->outputName(mFromIndex);
}
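// Feed the content of `src` into this input variable. Passing nullptr
// invalidates the input and everything that depends on it. If the shape,
// layout, or type differs from the current info, the host tensor is
// reallocated and downstream shape info is marked dirty; otherwise only
// the content is marked dirty.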
bool Variable::input(VARP src) {
    if (nullptr != mFrom->get() || VARP::CONSTANT == mFrom->mType) {
        MNN_ERROR("Can't input to no-input op\n");
        return false;
    }
    if (nullptr == src) {
        /* Close the Input */
        mFrom->visitOutputs([](EXPRP expr, int index) {
            auto recurse = expr->mValid;
            expr->mValid = false;
            return recurse;
        });
        mFrom->mValid = false;
        return false;
    }
    auto info = src->getInfo();
    std::shared_ptr<Variable::Info> tempInfo;
    if (nullptr == info) {
        tempInfo.reset(new Variable::Info);
        tempInfo->size = 0;
        tempInfo->type = halide_type_of<float>();
        info = tempInfo.get();
    }
    auto dstInfo = getInfo();
    bool needChange = nullptr == dstInfo || info->order != dstInfo->order || info->dim.size() != dstInfo->dim.size() ||
                      info->type != dstInfo->type;
    if (!needChange) {
        for (int i = 0; i < info->dim.size(); ++i) {
            if (dstInfo->dim[i] != info->dim[i]) {
                needChange = true;
                break;
            }
        }
    }
    if (!mFrom->mInside->mCache) {
        ExecutorScope::Current()->makeCache({mFrom}, false);
    }
    if (needChange) {
        mFrom->mInside->mOutputInfos[0] = *info;
        Utils::releaseMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
        Utils::copyInfoToTensor(mFrom->inside()->mOutputTensors[0], mFrom->inside()->mOutputInfos.data());
        Utils::allocMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
    }
    if (info->size) {
        auto dstPtr = writeInternal(false);
        auto srcPtr = src->readMap<void>();
        if (nullptr == dstPtr || nullptr == srcPtr) {
            //MNN_ERROR("Alloc memory error or compute src error in Variable::Input\n");
            return false;
        }
        ::memcpy(dstPtr, srcPtr, info->size * info->type.bytes());
    }
    if (needChange) {
        mFrom->visitOutputs([](EXPRP expr, int index) { return expr->setInfoDirty(); });
    } else {
        informDirty();
    }
    mFrom->mInside->mContentDirty = false;
    return true;
}
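// Redirect `dst` to the value of `src`. When both exprs expose the same
// number of outputs the whole Expr is replaced; otherwise only the VARP is
// redirected: consumers of `dst` are rewired to `src` and their caches and
// dirty flags reset.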
void Variable::replace(VARP dst, VARP src) {
    if (nullptr == src) {
        dst->setExpr(nullptr, 0);
        return;
    }
    if (nullptr == dst) {
        dst.mContent = src.mContent;
        return;
    }
    if (src->mFrom.get() == dst->mFrom.get()) {
        dst->mFromIndex = src->mFromIndex;
        return;
    }
    if (src->mFrom->outputSize() != dst->mFrom->outputSize()) {
        // Can't replace the Expr, just replace the VARP
        std::vector<Expr*> visited;
        dst->mFrom->visitOutputs([src, dst, &visited](EXPRP expr, int index) {
            if (expr->visited()) {
                return false;
            }
            expr->setVisited(true);
            visited.emplace_back(expr.get());
            expr->mInside->mCache.reset();
            expr->mInside->mCacheOffset = 0;
            expr->mValid = true;
            expr->mInside->mInfoDirty = true;
            expr->mInside->mContentDirty = true;
            return true;
        });
        for (auto v : visited) {
            v->setVisited(false);
        }
        dst->mFrom->visitOutputs([src, dst](EXPRP expr, int index) {
            for (int i = 0; i < expr->inputs().size(); ++i) {
                auto input = expr->inputs()[i];
                if (input == dst) {
                    expr->mInputs[i] = src;
                }
            }
            src->mFrom->mTo.emplace_back(expr);
            return false;
        });
        dst->mFrom = src->mFrom;
        dst->mFromIndex = src->mFromIndex;
        return;
    }
    Expr::replace(dst->mFrom, src->mFrom);
    dst->mFromIndex = src->mFromIndex;
}
const Variable::Info* Variable::getInfo() {
    if (nullptr == mFrom) {
        return nullptr;
    }
    auto res = mFrom->requireInfo();
    if (!res) {
        return nullptr;
    }
    return mFrom->mInside->mOutputInfos.data() + mFromIndex;
}
bool Variable::resize(INTS dims) {
    if (nullptr != mFrom->get() && VARP::INPUT != mFrom->mType) {
        MNN_ERROR("Can't resize variable not from input\n");
        return false;
    }
    auto& info = mFrom->mInside->mOutputInfos[0];
    if (dims.size() == info.dim.size()) {
        bool theSame = true;
        for (int i = 0; i < dims.size(); ++i) {
            if (info.dim[i] != dims[i]) {
                theSame = false;
                break;
            }
        }
        if (theSame) {
            return true;
        }
    }
    info.dim = dims;
    info.syncSize();
    Utils::copyInfoToTensor(mFrom->inside()->mOutputTensors[0], mFrom->inside()->mOutputInfos.data());
    Utils::releaseMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
    if (0 >= info.size) {
        return false;
    }
    bool res = Utils::allocMemoryForHostTensor(mFrom->inside()->mOutputTensors[0]);
    if (!res) {
        return false;
    }
    mFrom->mValid = true;
    mFrom->inside()->mInfoDirty = false;
    mFrom->inside()->mContentDirty = true;
    mFrom->visitOutputs([](EXPRP expr, int index) { return expr->setInfoDirty(); });
    return true;
}
void Expr::visit(EXPRP expr, const std::function<bool(EXPRP)>& before, const std::function<bool(EXPRP)>& after) {
    bool next = before(expr);
    if (!next) {
        return;
    }
    for (int i = 0; i < expr->inputs().size(); ++i) {
        visit(expr->inputs()[i]->mFrom, before, after);
    }
    after(expr);
}
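// Map the variable's content for reading. Leaf exprs return their host
// pointer directly, copying device-side tensors into a lazily created host
// tensor first. Non-leaf exprs trigger shape inference, build an execution
// cache if needed, run it, and map the cached output.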
void* Variable::readInternal(bool forShape) {
    if (nullptr == mFrom->get()) {
        if (VARP::INPUT == mFrom->mType) {
            if (mFrom->mInside->mContentDirty) {
                return nullptr;
            }
        }
        //MNN_ASSERT(nullptr != mFrom->inside()->mOutputTensors[0]->buffer().host);
        auto inside = mFrom->inside();
        auto originTensor = inside->mOutputTensors[0];
        if (WrapExecution::needWrap(originTensor, nullptr)) {
            // For StaticModule with an other-device runtime, we may create a Variable with the other device's memory.
            // This case won't occur for variable == INPUT, so a copy to host is needed.
            if (nullptr != inside->mHostTensor) {
                // The VARP will not be created as input, so we only need to copy once
                return inside->mHostTensor->host<void>();
            }
            inside->mHostTensor = new Tensor;
            TensorUtils::copyShape(originTensor, inside->mHostTensor, true);
            inside->mHostTensor->buffer().type = originTensor->getType();
            inside->mHostTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(inside->mHostTensor->size(), MNN_MEMORY_ALIGN_DEFAULT);
            TensorUtils::getDescribe(inside->mHostTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST;
            originTensor->copyToHostTensor(inside->mHostTensor);
            return inside->mHostTensor->host<void>();
        }
        return originTensor->buffer().host;
    }
    auto res = mFrom->requireInfo();
    if (false == res) {
        return nullptr;
    }
    auto cache = mFrom->inside()->mCache;
    if (nullptr == cache) {
        ExecutorScope::Current()->makeCache({mFrom}, forShape);
        cache = mFrom->inside()->mCache;
    }
    if (nullptr == cache) {
        return nullptr;
    }
    if (NO_ERROR != ExecutorScope::Current()->runCache(cache)) {
        return nullptr;
    }
    return Executor::mapOutput(cache.get(), mFrom->mInside->mCacheOffset + mFromIndex, mFrom->mInside->mOutputTensors[mFromIndex]);
}
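// Propagate a content change to the consumers of this variable: exprs whose
// shape depends on the content get their info marked dirty, exprs that only
// read the content get their cache marked content-dirty.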
void Variable::informDirty() {
    std::vector<Expr*> visited;
    mFrom->visitOutputs([&visited](EXPRP expr, int index) {
        if (expr->visited()) {
            return false;
        }
        visited.emplace_back(expr.get());
        expr->setVisited(true);
        if (expr->inside()->mReq.shapeNeedContent.empty()) {
            // Not init
            return false;
        }
        if (expr->inside()->mReq.shapeNeedContent[index]) {
            expr->setInfoDirty();
            expr->visitOutputs([](EXPRP e, int index) { return e->setInfoDirty(); });
            return false;
        }
        if (expr->inside()->mReq.contentNeedContent[index]) {
            if (expr->inside()->mCache != nullptr) {
                Executor::setContentDirty(expr->inside()->mCache.get());
            }
            return true;
        }
        return false;
    });
    for (auto e : visited) {
        e->setVisited(false);
    }
}
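// prepareCompute() batches shape inference and cache creation for several
// variables at once; compute() additionally runs the caches. A usage
// sketch (hypothetical variables):
//   Variable::prepareCompute({a, b});
//   auto aPtr = a->readMap<float>(); // presumably reuses the shared cache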
void Variable::prepareCompute(const std::vector<VARP>& vars, bool forceCpu) {
    std::vector<EXPRP> exprs;
    for (auto v : vars) {
        if (!v->expr().first->visited()) {
            v->expr().first->inside()->mCache = nullptr;
            v->expr().first->requireInfo();
            v->expr().first->setVisited(true);
            exprs.emplace_back(v->expr().first);
        }
    }
    for (auto v : vars) {
        v->expr().first->setVisited(false);
    }
    ExecutorScope::Current()->makeCache(std::move(exprs), forceCpu);
}
void Variable::compute(const std::vector<VARP>& vars, bool forceCPU) {
    prepareCompute(vars, forceCPU);
    for (auto& v : vars) {
        if (nullptr != v->mFrom) {
            auto inside = v->mFrom->inside();
            if (nullptr != inside && nullptr != inside->mCache) {
                ExecutorScope::Current()->runCache(inside->mCache);
            }
        }
    }
}
void* Variable::writeInternal(bool inform) {
    if (nullptr != mFrom->get()) {
        return nullptr;
    }
    if (inform) {
        informDirty();
    }
    mFrom->mInside->mContentDirty = false;
    return mFrom->inside()->mOutputTensors[0]->host<void>();
}
void Variable::unMap() {
    //mFrom->inside()->onUnMapContent(mFromIndex);
}
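// Visit every consumer expr of this expr, pruning expired weak references
// along the way. The callback is invoked once per (consumer, input index)
// pair and returns whether the traversal should recurse into that consumer.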
void Expr::visitOutputs(const std::function<bool(EXPRP, int)>& visit) {
    for (auto iter = mTo.begin(); iter != mTo.end();) {
        auto expr = iter->lock();
        if (nullptr == expr) {
            iter = mTo.erase(iter);
            continue;
        }
        bool recurse = false;
        auto inputs = expr->inputs();
        for (int i = 0; i < inputs.size(); ++i) {
            if (inputs[i]->mFrom.get() == this) {
                recurse = recurse || visit(expr, i);
            }
        }
        if (recurse) {
            expr->visitOutputs(visit);
        }
        iter++;
    }
}
bool Expr::setInfoDirty() {
    if (mInside->mInfoDirty && mValid) {
        //MNN_PRINT("End Info Dirty for %s\n", mName.c_str());
        return false;
    }
    //MNN_PRINT("Set Info Dirty for %s\n", mName.c_str());
    mInside->mInfoDirty = true;
    mInside->mContentDirty = true;
    mValid = true;
    if (mInside->mCache != nullptr) {
        Executor::setShapeDirty(mInside->mCache.get());
    }
    for (auto o : mInside->mOutputTensors) {
        Utils::releaseMemoryForHostTensor(o);
    }
    return true;
}
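// Load a serialized MNN model into expression form. The file overload reads
// and merges the raw buffer, then delegates to the buffer overload below.
// A minimal usage sketch (hypothetical path):
//   auto vars = Variable::load("model.mnn");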
std::vector<VARP> Variable::load(const char* fileName) {
    AutoStorage<uint8_t> buffer;
    {
        FileLoader loader(fileName);
        if (!loader.valid()) {
            MNN_ERROR("Error for open %s\n", fileName);
            return {};
        }
        loader.read();
        if (!loader.valid()) {
            return {};
        }
        loader.merge(buffer);
        if (buffer.get() == nullptr) {
            return {};
        }
    }
    return load(buffer.get(), buffer.size());
}
std::vector<VARP> Variable::load(const uint8_t* buffer, size_t length) {
    AUTOTIME;
    flatbuffers::Verifier verify((const uint8_t*)(buffer), length);
    if (false == VerifyNetBuffer(verify)) {
        MNN_PRINT("Invalidate buffer to create variable\n");
        return {};
    }
    std::unique_ptr<NetT> source(UnPackNet(buffer));
    if (nullptr == source) {
        return {};
    }
    if (source->oplists.empty()) {
        MNN_ERROR("Invalid net\n");
        return {};
    }
    // FUNC_PRINT(source->oplists.size());
    auto opSize = source->oplists.size();
    auto tensorCount = source->tensorName.size();
    if (tensorCount == 0) {
        tensorCount = source->tensorNumber;
    }
    std::vector<VARP> variable;
    variable.reserve(tensorCount);
    std::map<int, VARP> variableMap;
    // Generate all Exprs in the order of the net
    for (int i = 0; i < opSize; ++i) {
        std::vector<VARP> inputs;
        auto op = source->oplists[i].get();
        for (int index = 0; index < op->inputIndexes.size(); ++index) {
            auto inputIndex = op->inputIndexes[index];
            if (variableMap.find(inputIndex) == variableMap.end()) {
                MNN_ERROR("Can't find variable for %s, the graph is error\n", op->name.c_str());
                break;
            }
            inputs.emplace_back(variableMap[inputIndex]);
        }
        EXPRP expr = Expr::create(source->oplists[i].get(), inputs, (int)op->outputIndexes.size());
        expr->setName(source->oplists[i]->name);
        for (int index = 0; index < op->outputIndexes.size(); ++index) {
            auto outputIndex = op->outputIndexes[index];
            if (variableMap.find(outputIndex) == variableMap.end()) {
                auto newVariable = Variable::create(expr, index);
                if (source->tensorName.size() > outputIndex) {
                    newVariable->setName(source->tensorName[outputIndex]);
                }
                variableMap[outputIndex] = newVariable;
                variable.emplace_back(newVariable);
            }
        }
    }
    return variable;
}
std::map<std::string, VARP> Variable::loadMap(const uint8_t* buffer, size_t length) {
    AUTOTIME;
    auto variables = load(buffer, length);
    std::map<std::string, VARP> varMap;
    for (auto v : variables) {
        varMap[v->name()] = v;
    }
    return varMap;
}
std::map<std::string, VARP> Variable::loadMap(const char* fileName) {
    AUTOTIME;
    auto variables = load(fileName);
    std::map<std::string, VARP> varMap;
    for (auto v : variables) {
        varMap[v->name()] = v;
    }
    return varMap;
}
std::vector<VARP> Variable::mapToSequence(const std::map<std::string, VARP>& source) {
    std::vector<VARP> outputs;
    outputs.reserve(source.size());
    for (auto& iter : source) {
        outputs.emplace_back(iter.second);
    }
    return outputs;
}
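// Serialize the subgraph reachable from `vars` into a NetT: exprs are
// emitted in execution order, leaf exprs are converted back into
// Input/Const/TrainableParam ops, and tensor indices are assigned from each
// expr's output offset.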
void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
    auto executeOrder = getExecuteOrder(vars);
    // Build the Expr -> tensor-offset map
    std::map<EXPRP, int> varIndexInfo;
    {
        int tensorOffset = 0;
        for (int i = 0; i < executeOrder.size(); ++i) {
            auto expr = executeOrder[i];
            auto outputSize = executeOrder[i]->outputSize();
            varIndexInfo[expr] = tensorOffset;
            tensorOffset += outputSize;
        }
        dest->tensorName.resize(tensorOffset);
    }
    // Create all ops
    for (int index = 0; index < executeOrder.size(); ++index) {
        auto expr = executeOrder[index];
        auto mOp = expr->get();
        std::unique_ptr<OpT> op;
        if (nullptr != mOp) {
            op.reset(mOp->UnPack());
        } else {
            MNN_ASSERT(1 == expr->outputSize());
            auto& info = expr->mInside->mOutputInfos[0];
            const void* ptr = expr->mInside->mOutputTensors[0]->host<void>();
            VARP temp;
            if (nullptr == ptr || expr->mInside->mOutputTensors[0]->deviceId() > 0) {
                temp = Variable::create(expr);
                ptr = temp->readMap<void>();
            }
            op.reset(new OpT);
            if (expr->mType != VARP::INPUT) {
                auto blob = new BlobT;
                blob->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info.order);
                blob->dims = info.dim;
                if (info.type.code == halide_type_float) {
                    blob->dataType = DataType_DT_FLOAT;
                    blob->float32s.resize(info.size);
                    ::memcpy(blob->float32s.data(), ptr, info.size * sizeof(float));
                } else if (info.type.code == halide_type_int && info.type.bits == 32) {
                    blob->dataType = DataType_DT_INT32;
                    blob->int32s.resize(info.size);
                    ::memcpy(blob->int32s.data(), ptr, info.size * sizeof(int));
                } else if (info.type.code == halide_type_int && info.type.bits == 8) {
                    blob->dataType = DataType_DT_INT8;
                    blob->int8s.resize(info.size);
                    ::memcpy(blob->int8s.data(), ptr, info.size * sizeof(int8_t));
                } else if (info.type.code == halide_type_uint && info.type.bits == 8) {
                    blob->dataType = DataType_DT_UINT8;
                    blob->uint8s.resize(info.size);
                    ::memcpy(blob->uint8s.data(), ptr, info.size * sizeof(uint8_t));
                }
                op->type = OpType_Const;
                if (expr->mType == VARP::TRAINABLE) {
                    op->type = OpType_TrainableParam;
                }
                op->main.type = OpParameter_Blob;
                op->main.value = blob;
            } else {
                op->type = OpType_Input;
                op->main.type = OpParameter_Input;
                op->main.value = new InputT;
                op->main.AsInput()->dtype = (MNN::DataType)Utils::convertDataType(info.type);
                MNN_ASSERT(op->main.AsInput()->dtype != DataType_DT_INVALID);
                op->main.AsInput()->dims = info.dim;
                op->main.AsInput()->dformat = (MNN_DATA_FORMAT)Utils::convertFormat(info.order);
            }
        }
        op->name = expr->name();
        op->inputIndexes.resize(expr->inputs().size());
        for (int i = 0; i < op->inputIndexes.size(); ++i) {
            auto inputExpr = expr->inputs()[i]->expr();
            op->inputIndexes[i] = varIndexInfo[inputExpr.first] + inputExpr.second;
        }
        if (op->name.empty()) {
            op->name = EnumNameOpType(op->type) + numberToString(index + 1);
        }
        op->outputIndexes.resize(expr->outputSize());
        auto tensorIndexOffset = varIndexInfo[expr];
        for (int v = 0; v < expr->outputSize(); ++v) {
            op->outputIndexes[v] = tensorIndexOffset + v;
            dest->tensorName[tensorIndexOffset + v] = expr->outputName(v);
        }
        dest->oplists.emplace_back(std::move(op));
    }
    // Fill empty tensor names with a default derived from the op name
    for (int index = 0; index < executeOrder.size(); ++index) {
        auto expr = executeOrder[index];
        auto op = dest->oplists[index].get();
        auto tensorIndexOffset = varIndexInfo[expr];
        for (int v = 0; v < expr->outputSize(); ++v) {
            auto subindex = tensorIndexOffset + v;
            if (dest->tensorName[subindex].empty()) {
                if (v == 0) {
                    dest->tensorName[subindex] = op->name;
                } else {
                    dest->tensorName[subindex] = op->name + numberToString(v);
                }
            }
        }
    }
}
void Variable::save(const std::vector<VARP>& vars, const char* fileName) {
    std::unique_ptr<NetT> net(new NetT);
    save(vars, net.get());
    // FUNC_PRINT(net->oplists.size());
    flatbuffers::FlatBufferBuilder builder(1024);
    auto offset = Net::Pack(builder, net.get());
    builder.Finish(offset);
    // TODO: use FileWriter instead
    FILE* f = fopen(fileName, "wb");
    if (nullptr == f) {
        MNN_ERROR("Open %s error\n", fileName);
        return;
    }
    static const size_t block = 4096;
    size_t totalSize = builder.GetSize();
    size_t blockSize = UP_DIV(totalSize, block);
    for (size_t i = 0; i < blockSize; ++i) {
        size_t sta = block * i;
        size_t fin = std::min(sta + block, totalSize);
        if (fin > sta) {
            auto realSize = fwrite((const char*)builder.GetBufferPointer() + sta, 1, fin - sta, f);
            if (realSize != fin - sta) {
                MNN_ERROR("Write %s error\n", fileName);
            }
        }
    }
    fclose(f);
}
std::pair<std::map<std::string, VARP>, std::map<std::string, VARP>> Variable::getInputAndOutput(const std::map<std::string, VARP>& allVariable) {
    std::pair<std::map<std::string, VARP>, std::map<std::string, VARP>> res;
    for (auto& iter : allVariable) {
        auto var = iter.second;
        if (var->expr().first->get() == nullptr && var->expr().first->mType == VARP::INPUT) {
            res.first[var->name()] = var;
        }
        if (var->linkNumber() == 0) {
            res.second[var->name()] = var;
        }
    }
    return res;
}
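// Topologically order the exprs reachable from `outputs` via a depth-first
// post-order walk, using the visited flags to avoid duplicates.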
std::vector<EXPRP> Variable::getExecuteOrder(const std::vector<VARP>& outputs) {
    std::vector<EXPRP> sequence;
    for (auto output : outputs) {
        Expr::visit(
            output->mFrom, [](EXPRP expr) { return !expr->visited(); },
            [&sequence](EXPRP expr) {
                //FUNC_PRINT_ALL(var->name().c_str(), s);
                if (!expr->visited()) {
                    sequence.emplace_back(expr);
                    expr->setVisited(true);
                }
                return true;
            });
    }
    for (auto expr : sequence) {
        expr->setVisited(false);
    }
    return sequence;
}
VARP VARP::operator+(VARP var) const {
    return _Add(VARP(mContent), var);
}
VARP VARP::operator-(VARP var) const {
    return _Subtract(VARP(mContent), var);
}
VARP VARP::operator*(VARP var) const {
    return _Multiply(VARP(mContent), var);
}
VARP VARP::operator/(VARP var) const {
    return _Divide(VARP(mContent), var);
}
VARP VARP::mean(INTS dims) const {
    return _ReduceMean(VARP(mContent), dims);
}
VARP VARP::sum(INTS dims) const {
    return _ReduceSum(VARP(mContent), dims);
}
} // namespace Express
} // namespace MNN
  981. } // namespace MNN