@@ -329,200 +329,337 @@ void Transpose(uint8_t* output, const uint8_t* input, const TransposeParam* cpuP
}
}

+// For the following transpose kernels:
+// maxCount is the number of threads, i.e., the number of elements of the output format
+// inChannelPack is the packed channel count of the input format
+// divOutChannelPack is the fast divider (DivModFast) for the packed channel count of the output format
+
+// copy kernel
template<typename T0, typename T1>
-__global__ void NCHW_2_NHWC8(const T0* input,
-                             T1* output,
-                             const int maxCount,
-                             const int channel,
-                             const int area,
-                             const int channel_pack,
-                             DivModFast d_ocp,
-                             DivModFast d_area
+__global__ void NCHW_2_NCHW(const T0* input,
+                            T1* output,
+                            const int maxCount
) {
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int area_idx, temp, chnlp_idx, batch_idx;
-        d_ocp.divmod(index, temp, chnlp_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        output[index] = (T1)input[index];
+    }
+}
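
Editor's note on the kernels below: each one runs a grid-stride loop with one thread per element of the output format and decodes the flat output index with DivModFast. A minimal sketch of the assumed divmod contract (the reference struct is illustrative; the real helper is MNN's DivModFast):

// Editorial sketch, not part of the patch: the semantics the kernels rely on.
struct DivModFastRef {
    int d;
    explicit DivModFastRef(int divisor) : d(divisor) {}
    // Writes q = x / d and r = x % d in one call, as MNN's DivModFast is
    // assumed to do (with faster integer magic internally).
    __host__ __device__ void divmod(int x, int& q, int& r) const {
        q = x / d;
        r = x % d;
    }
};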

-        if(chnlp_idx >= channel) {
-            output[index] = (T1)0.0f;
-            continue;
-        }
-        int src_offset = (batch_idx * channel + chnlp_idx) * area + area_idx;
+// NHWC -> NCHW
+template<typename T0, typename T1>
+__global__ void NHWC_2_NCHW(const T0* input,
+                            T1* output,
+                            const int maxCount,
+                            const int channel, // redundant parameter
+                            const int area,
+                            const int inChannelPack,
+                            DivModFast divOutChannelPack,
+                            DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
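
The kernel above walks the NCHW output linearly (area innermost, then channel, then batch) and gathers from the NHWC source. A host-side reference of the same index math, for checking the decode order (the function is editorial, not part of the patch):

// Editorial sketch: CPU reference for NHWC -> NCHW with inChannelPack == channel.
void nhwcToNchwRef(const float* src, float* dst, int batch, int channel, int area) {
    for (int b = 0; b < batch; ++b)
        for (int c = 0; c < channel; ++c)
            for (int hw = 0; hw < area; ++hw)
                // Output index (b * channel + c) * area + hw is what the kernel
                // decodes via divArea first, then divOutChannelPack.
                dst[(b * channel + c) * area + hw] =
                    src[(b * area + hw) * channel + c];
}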
+
+// NHWC8 -> NCHW
+template<typename T0, typename T1>
+__global__ void NHWC8_2_NCHW(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel, // redundant parameter
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
        output[index] = (T1)input[src_offset];
    }
}

+// C4NHW4 -> NCHW
+template<typename T0, typename T1>
+__global__ void C4NHW4_2_NCHW(const T0* input,
+                              T1* output,
+                              const int maxCount,
+                              const int channel,
+                              const int area,
+                              const int inChannelPack, // redundant parameter
+                              DivModFast divOutChannelPack,
+                              DivModFast divArea
+) {
+    const int batch = (maxCount / channel) / area;
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
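
Both C4NHW4 kernels above use the same source-offset formula; written out as a standalone helper it reads (editorial, for clarity only):

// Editorial sketch: the C4NHW4 offset used by the kernels above. C4NHW4 stores
// channels in groups of 4 with the group index outermost, so
// offset(c, b, hw) = ((c / 4 * batch + b) * area + hw) * 4 + c % 4.
__host__ __device__ inline int c4nhw4Offset(int chnl_idx, int batch_idx, int area_idx,
                                            int batch, int area) {
    int c4_idx = chnl_idx >> 2; // channel group
    int cL_idx = chnl_idx & 3;  // lane within the group
    return ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
}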
+
+// NCHW -> NHWC
template<typename T0, typename T1>
__global__ void NCHW_2_NHWC(const T0* input,
-                            T1* output,
-                            const int maxCount,
-                            const int channel,
-                            const int area,
-                            const int channel_pack,
-                            DivModFast d_oc,
-                            DivModFast d_area
+                            T1* output,
+                            const int maxCount,
+                            const int channel, // redundant parameter
+                            const int area,
+                            const int inChannelPack,
+                            DivModFast divOutChannelPack,
+                            DivModFast divArea
) {
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
        int area_idx, temp, chnl_idx, batch_idx;
-        d_oc.divmod(index, temp, chnl_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
-
-        int src_offset = (batch_idx * channel + chnl_idx) * area + area_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// NHWC8 -> NHWC
+template<typename T0, typename T1>
+__global__ void NHWC8_2_NHWC(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel, // redundant parameter
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// C4NHW4 -> NHWC
+template<typename T0, typename T1>
+__global__ void C4NHW4_2_NHWC(const T0* input,
+                              T1* output,
+                              const int maxCount,
+                              const int channel,
+                              const int area,
+                              const int inChannelPack, // redundant parameter
+                              DivModFast divOutChannelPack,
+                              DivModFast divArea
+) {
+    const int batch = (maxCount / channel) / area;
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
        output[index] = (T1)input[src_offset];
    }
}

+// NHWC -> NHWC8
template<typename T0, typename T1>
__global__ void NHWC_2_NHWC8(const T0* input,
-                             T1* output,
-                             const int maxCount,
-                             const int channel,
-                             const int area,
-                             const int channel_pack,
-                             DivModFast d_ocp,
-                             DivModFast d_area
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
) {
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int area_idx, temp, chnlp_idx, batch_idx;
-        d_ocp.divmod(index, temp, chnlp_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);

-        if(chnlp_idx >= channel) {
+        if(chnl_idx >= channel) {
            output[index] = (T1)0.0f;
            continue;
        }
-        int src_offset = (batch_idx * area + area_idx) * channel + chnlp_idx;
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
        output[index] = (T1)input[src_offset];
    }
}
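
NHWC8 pads the channel dimension up to a multiple of 8, so threads whose chnl_idx lands in the padded tail write zero instead of reading the source. A CPU reference of the same rule (editorial sketch; names are illustrative):

// Editorial sketch: CPU reference for NHWC -> NHWC8 including the zero padding.
void nhwcToNhwc8Ref(const float* src, float* dst, int batch, int channel, int area) {
    const int channelPack = ((channel + 7) / 8) * 8; // UP_DIV(channel, 8) * 8
    for (int b = 0; b < batch; ++b)
        for (int hw = 0; hw < area; ++hw)
            for (int c = 0; c < channelPack; ++c)
                dst[(b * area + hw) * channelPack + c] =
                    (c < channel) ? src[(b * area + hw) * channel + c] : 0.0f;
}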

+// NCHW -> NHWC8
template<typename T0, typename T1>
-__global__ void NHWC8_2_NCHW(const T0* input,
-                             T1* output,
-                             const int maxCount,
-                             const int channel,
-                             const int area,
-                             const int channel_pack,
-                             DivModFast d_oc,
-                             DivModFast d_area
+__global__ void NCHW_2_NHWC8(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
) {
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);

-        int area_idx, temp, channel_idx, batch_idx;
-        d_area.divmod(index, temp, area_idx);
-        d_oc.divmod(temp, batch_idx, channel_idx);
+        if(chnl_idx >= channel) {
+            output[index] = (T1)0.0f;
+            continue;
+        }

-        int src_offset = (batch_idx * area + area_idx) * channel_pack + channel_idx;
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
        output[index] = (T1)input[src_offset];
    }
}

+// C4NHW4 -> NHWC8
template<typename T0, typename T1>
-__global__ void NHWC8_2_NHWC(const T0* input,
-                             T1* output,
-                             const int maxCount,
-                             const int channel,
-                             const int area,
-                             const int channel_pack,
-                             DivModFast d_oc,
-                             DivModFast d_area
+__global__ void C4NHW4_2_NHWC8(const T0* input,
+                               T1* output,
+                               const int maxCount,
+                               const int channel,
+                               const int area,
+                               const int inChannelPack, // redundant parameter
+                               DivModFast divOutChannelPack,
+                               DivModFast divArea
) {
+    const int batch = (maxCount / (UP_DIV(channel, 8) * 8)) / area;
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);

-        int area_idx, temp, channel_idx, batch_idx;
-        d_oc.divmod(index, temp, channel_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        if(chnl_idx >= channel) {
+            output[index] = (T1)0.0f;
+            continue;
+        }

-        int src_offset = (batch_idx * area + area_idx) * channel_pack + channel_idx;
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
        output[index] = (T1)input[src_offset];
    }
}
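
Unlike the earlier C4NHW4 kernels, C4NHW4_2_NHWC8 cannot recover the batch as maxCount / channel / area, because maxCount counts padded NHWC8 elements; the division has to use the packed channel count. As a standalone helper (editorial):

// Editorial sketch: how C4NHW4_2_NHWC8 recovers batch from its launch size.
// One thread is launched per NHWC8 output element, so
// maxCount = batch * area * (UP_DIV(channel, 8) * 8), hence:
inline int batchFromNhwc8MaxCount(int maxCount, int channel, int area) {
    const int channelPack = ((channel + 7) / 8) * 8; // UP_DIV(channel, 8) * 8
    return (maxCount / channelPack) / area;
}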

+// NHWC -> C4NHW4
template<typename T0, typename T1>
-__global__ void NCHW_2_NCHW(const T0* input,
-                            T1* output,
-                            const int maxCount
+__global__ void NHWC_2_C4NHW4(const T0* input,
+                              T1* output,
+                              const int maxCount,
+                              const int channel,
+                              const int area,
+                              const int inChannelPack,
+                              DivModFast divOutChannelPack,
+                              DivModFast divArea
) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        output[index] = (T1)input[index];
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3;  // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+
+        if (chnl_idx >= channel) {
+            output[dst_offset] = (T1)0.0f;
+            continue;
+        }
+
+        output[dst_offset] = (T1)input[src_offset];
    }
}

+// NCHW -> C4NHW4
template<typename T0, typename T1>
-__global__ void C4NHW4_2_NHWC8(const T0* input,
-                               T1* output,
-                               const int maxCount,
-                               const int batch,
-                               const int area,
-                               const int channel,
-                               const int channel_pack
+__global__ void NCHW_2_C4NHW4(const T0* input,
+                              T1* output,
+                              const int maxCount,
+                              const int channel,
+                              const int area,
+                              const int inChannelPack,
+                              DivModFast divOutChannelPack,
+                              DivModFast divArea
) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int c_idx = index % channel_pack;
-        int temp = index / channel_pack;
-        int hw_idx = temp % area;
-        int batch_idx = temp / area;
-
-        if(c_idx >= channel) {
-            output[index] = (T1)0.0f;
-            continue;
-        }
-        int c4_idx = c_idx >> 2;
-        int cL_idx = c_idx & 3;
-        output[index] = (T1)input[((c4_idx * batch + batch_idx) * area + hw_idx) * 4 + cL_idx];
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3;  // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
+
+        if (chnl_idx >= channel) {
+            output[dst_offset] = (T1)0.0f;
+            continue;
+        }
+
+        output[dst_offset] = (T1)input[src_offset];
    }
}

+// NHWC8 -> C4NHW4
template<typename T0, typename T1>
__global__ void NHWC8_2_C4NHW4(const T0* input,
-                               T1* output,
-                               const int maxCount,
-                               const int batch,
-                               const int channel,
-                               const int area,
-                               const int channel_pack
+                               T1* output,
+                               const int maxCount,
+                               const int channel,
+                               const int area,
+                               const int inChannelPack,
+                               DivModFast divOutChannelPack,
+                               DivModFast divArea
) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int c_idx = index % channel_pack;
-        int temp = index / channel_pack;
-        int hw_idx = temp % area;
-        int batch_idx = temp / area;
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3;  // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;

-        int channel_8 = ((channel + 7) / 8) * 8;
-        int c4_idx = c_idx >> 2;
-        int cL_idx = c_idx & 3;
-        output[((c4_idx * batch + batch_idx) * area + hw_idx) * 4 + cL_idx] =
-            (T1)input[(batch_idx * area + hw_idx) * channel_8 + c_idx];
+        output[dst_offset] = (T1)input[src_offset];
    }
}

template<class T0, class T1>
static void insideFormatConvert(T0* input, T1* output, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
-    const int area, const int batch, const int channel) {
+    const int area, const int batch, const int channel, const bool srcDevice, const bool dstDevice) {
    DivModFast d_oc(channel);
-    DivModFast d_ocp(UP_DIV(channel, 8) * 8);
+    DivModFast d_oc4(UP_DIV(channel, 4) * 4);
+    DivModFast d_oc8(UP_DIV(channel, 8) * 8);
    DivModFast d_area(area);

-    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NCHW_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                                                        d_ocp, d_area);
-        checkKernelErrors;
-        return;
-    }
-    if(srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NHWC_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                                                        d_ocp, d_area);
-        checkKernelErrors;
-        return;
-    }
-    if((srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NCHW) || \
+    // NCHW -> NCHW
+    // NHWC -> NHWC
+    if ((srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NCHW) || \
        (srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NHWC)) {
        const int maxCount = batch * area * channel;
        const int block_num = runtime->blocks_num(maxCount);
@@ -531,168 +668,178 @@ static void insideFormatConvert(T0* input, T1* output, MNN_DATA_FORMAT srcDataFo
        checkKernelErrors;
        return;
    }
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
-        const int maxCount = batch * area * channel;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NHWC8_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                                                        d_oc, d_area);
-        checkKernelErrors;
+
+    // NC4HW4 -> NC4HW4
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if(!srcDevice && dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc8, d_area);
+            checkKernelErrors;
+        } else if (srcDevice && !dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NCHW_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount);
+            checkKernelErrors;
+        }
        return;
    }
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+
+    // NHWC -> NCHW
+    if (srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
        const int maxCount = batch * area * channel;
        const int block_num = runtime->blocks_num(maxCount);
        const int block_size = runtime->threads_num();
-        NHWC8_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                                                        d_oc, d_area);
+        NHWC_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc, d_area);
        checkKernelErrors;
        return;
    }
-    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+
+    // NC4HW4 -> NCHW
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
+        if (!srcDevice) {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc, d_area);
+            checkKernelErrors;
+        }
+        return;
+    }
+
+    // NCHW -> NHWC
+    if (srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
        const int maxCount = batch * area * channel;
        const int block_num = runtime->blocks_num(maxCount);
        const int block_size = runtime->threads_num();
-        NCHW_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                                                       d_oc, d_area);
+        NCHW_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc, d_area);
        checkKernelErrors;
        return;
    }
-    MNN_PRINT("insideFormatConvert form %d to %d, not support\n", (int)srcDataFormat, (int)dstDataFormat);
-
-}
-
-void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
-    const int area, const int batch, const int channel, const Tensor* srcTensor, int precision, bool srcDevice, bool dstDevice) {
-
-    bool isFp16 = (precision == 2);
-    bool isBf16 = (precision == 3);
-    if(batch == 0 || area == 0 || channel == 0) {
-        MNN_PRINT("Error: formatConvert size batch:%d - plane:%d - channel:%d, format:%d->%d, device:%d->%d\n", batch, area, channel, srcDataFormat, dstDataFormat, srcDevice, dstDevice);
-        return;
-    }
-
-    auto des = TensorUtils::getDescribe(srcTensor);
-    if ((des->quantAttr.get() != nullptr && des->type == DataType_DT_INT8) || srcTensor->getType().bits == 8) {
-        if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-            if(!srcDevice && dstDevice) {
-                const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-                const int block_num = runtime->blocks_num(maxCount);
-                const int block_size = runtime->threads_num();
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((int8_t *)input, (int8_t *)output,
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-                return;
-            }
-
-            if(srcDevice && !dstDevice) {
-                const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
-                const int block_num = runtime->blocks_num(maxCount);
-                const int block_size = runtime->threads_num();
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((int8_t *)input, (int8_t *)output,
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-                return;
-            }
+    // NC4HW4 -> NHWC
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+        if (!srcDevice) {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc, d_area);
+            checkKernelErrors;
        }
-
-        insideFormatConvert<int8_t, int8_t>((int8_t *)input, (int8_t *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
        return;
    }

-    isFp16 = isFp16 & (halide_type_float == srcTensor->getType().code);
-    isBf16 = isBf16 & (halide_type_float == srcTensor->getType().code);
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        if(!srcDevice && dstDevice) {
+    // NCHW -> NC4HW4
+    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if (!dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NCHW_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
            const int block_num = runtime->blocks_num(maxCount);
            const int block_size = runtime->threads_num();
-            if(isFp16) {
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (half *)output,
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-            } else if(isBf16) {
-                #ifdef ENABLE_CUDA_BF16
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (__nv_bfloat16 *)output,
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-                #endif
-            } else {
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (float *)output,
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-            }
-            return;
+            NCHW_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc8, d_area);
+            checkKernelErrors;
        }
+        return;
+    }

-        if(srcDevice && !dstDevice) {
+    // NHWC -> NC4HW4
+    if(srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if (!dstDevice) {
            const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
            const int block_num = runtime->blocks_num(maxCount);
            const int block_size = runtime->threads_num();
-            if(isFp16) {
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((half *)input, (float *)output,
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-            } else if(isBf16) {
-                #ifdef ENABLE_CUDA_BF16
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((__nv_bfloat16 *)input, (float *)output,
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-                #endif
-            } else {
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((float *)input, (float *)output,
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-            }
-            return;
-        }
-
-        if(srcDevice && dstDevice) {
+            NHWC_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
            const int block_num = runtime->blocks_num(maxCount);
            const int block_size = runtime->threads_num();
-            if(isFp16 || isBf16) {
-                NCHW_2_NCHW<half, half><<<block_num, block_size>>>((half *)input, (half *)output, maxCount);
-                checkKernelErrors;
-            } else {
-                NCHW_2_NCHW<float, float><<<block_num, block_size>>>((float *)input, (float *)output, maxCount);
-                checkKernelErrors;
-            }
-            return;
+            NHWC_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc8, d_area);
+            checkKernelErrors;
        }
+        return;
+    }
+
+    MNN_ERROR("CUDA backend doesn't support this format conversion.\n");
+    MNN_ASSERT(false);
+    return;
+}
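
A reading note on the dispatch above (editorial): the srcDevice/dstDevice flags matter because an NC4HW4 tensor is laid out as C4NHW4 in host memory but as NHWC8 in device memory, so the same logical conversion picks different kernels depending on where each buffer lives:

// Editorial summary, derived from the branches above (no new logic):
//   host C4NHW4   -> device NHWC8 : C4NHW4_2_NHWC8, one thread per NHWC8 element
//   device NHWC8  -> host C4NHW4  : NHWC8_2_C4NHW4, one thread per C4NHW4 element
//   otherwise (same side)         : NCHW_2_NCHW, a plain element-wise copy
// Likewise the NCHW/NHWC branches pick the C4NHW4_* kernels when the NC4HW4
// buffer is on the host side and the NHWC8_* kernels when it is on the device.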
+
+void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
+    const int area, const int batch, const int channel, const Tensor* srcTensor, int precision, bool srcDevice, bool dstDevice) {
+    if(batch == 0 || area == 0 || channel == 0) {
+        MNN_PRINT("Error: formatConvert size batch:%d - plane:%d - channel:%d, format:%d->%d, device:%d->%d\n", batch, area, channel, srcDataFormat, dstDataFormat, srcDevice, dstDevice);
+        return;
+    }
+
+    bool isFp16 = (precision == 2) && (halide_type_float == srcTensor->getType().code);
+    bool isBf16 = (precision == 3) && (halide_type_float == srcTensor->getType().code);
+
+    // int8 case
+    auto des = TensorUtils::getDescribe(srcTensor);
+    if ((des->quantAttr.get() != nullptr && des->type == DataType_DT_INT8) || srcTensor->getType().bits == 8) {
+        insideFormatConvert<int8_t, int8_t>((int8_t *)input, (int8_t *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
+        return;
    }

+    // FP case
    if(!srcDevice) {
        if(isFp16) {
-            insideFormatConvert<float, half>((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, half>((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        } else if(isBf16) {
            #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<float, __nv_bfloat16>((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, __nv_bfloat16>((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
            #endif
        } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        }
    } else if(!dstDevice) {
        if(isFp16) {
-            insideFormatConvert<half, float>((half *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<half, float>((half *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        } else if(isBf16) {
            #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<__nv_bfloat16, float>((__nv_bfloat16 *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<__nv_bfloat16, float>((__nv_bfloat16 *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
            #endif
        } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        }
    } else {
        if(isFp16) {
-            insideFormatConvert<half, half>((half *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<half, half>((half *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        } else if(isBf16) {
            #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<__nv_bfloat16, __nv_bfloat16>((__nv_bfloat16 *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<__nv_bfloat16, __nv_bfloat16>((__nv_bfloat16 *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
            #endif
        } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
        }
    }
+    return;
}
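
An editorial usage sketch of the entry point above: uploading a host-side NC4HW4 float tensor into a device-side NC4HW4 fp16 buffer (precision == 2). The runtime, tensor, and sizes here are illustrative assumptions, not taken from the patch:

// Editorial sketch: FormatConvert picks the fp16 host->device path internally
// (isFp16 requires a float srcTensor and precision == 2).
void uploadNc4hw4AsFp16(void* deviceOut, void* hostIn,
                        CUDARuntime* runtime, const Tensor* srcTensor) {
    const int batch = 1, channel = 3, area = 224 * 224;
    FormatConvert(deviceOut, hostIn,
                  MNN_DATA_FORMAT_NC4HW4, MNN_DATA_FORMAT_NC4HW4,
                  runtime, area, batch, channel, srcTensor,
                  /*precision=*/2, /*srcDevice=*/false, /*dstDevice=*/true);
}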