目录
1. 引言2. 模型结构3. 计算模型的 FLOPs3.1. tensorflow 1.12.03.2. tensorflow 2.3.13.3. pytorch 1.10.1+cu1023.4. 结果对比4. 总结本文主要讨论如何计算 tensorflow 和 pytorch 模型的 FLOPs。如有表述不当之处欢迎批评指正。欢迎任何形式的转载,但请务必注明出处。
1. 引言
FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。
2. 模型结构
为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。
(资料图)
表 1 模型结构及主要参数
Layers | channels | Kernels | Strides | Units | Activation |
---|---|---|---|---|---|
Conv2D | 32 | (4,4) | (1,2) | \ | relu |
GRU | \ | \ | \ | 96 | \ |
Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:
from tensorflow.keras.layers import * from tensorflow.keras.models import load_model, Model def test_model_tf(Input_shape): # shape: [B, C, T, F] main_input = Input(batch_shape=Input_shape, name="main_inputs") conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation="relu", data_format="channels_first", name="conv")(main_input) # shape: [B, T, FC] gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv) gru = GRU(units=96, reset_after=True, return_sequences=True, name="gru")(gru) output = Dense(256, activation="sigmoid", name="output")(gru) model = Model(inputs=[main_input], outputs=[output]) return model
用 pytorch 实现该模型的代码为:
import torch import torch.nn as nn class test_model_torch(nn.Module): def __init__(self): super(test_model_torch, self).__init__() self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2)) self.relu = nn.ReLU() self.gru = nn.GRU(input_size=4064, hidden_size=96) self.fc = nn.Linear(96, 256) self.sigmoid = nn.Sigmoid() def forward(self, inputs): # shape: [B, C, T, F] out = self.conv2d(inputs) out = self.relu(out) # shape: [B, T, FC] batch, channel, frame, freq = out.size() out = torch.reshape(out, (batch, frame, freq*channel)) out, _ = self.gru(out) out = self.fc(out) out = self.sigmoid(out) return out
3. 计算模型的 FLOPs
本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:
import tensorflow as tf import tensorflow.keras.backend as K def get_flops(model): run_meta = tf.RunMetadata() opts = tf.profiler.ProfileOptionBuilder.float_operation() flops = tf.profiler.profile(graph=K.get_session().graph, run_meta=run_meta, cmd="op", options=opts) return flops.total_float_ops if __name__ == "__main__": x = K.random_normal(shape=(1, 1, 100, 256)) model = test_model_tf(x.shape) print("FLOPs of tensorflow 1.12.0:", get_flops(model))
3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :
import tensorflow.compat.v1 as tf import tensorflow.compat.v1.keras.backend as K tf.disable_eager_execution() def get_flops(model): run_meta = tf.RunMetadata() opts = tf.profiler.ProfileOptionBuilder.float_operation() flops = tf.profiler.profile(graph=K.get_session().graph, run_meta=run_meta, cmd="op", options=opts) return flops.total_float_ops if __name__ == "__main__": x = K.random_normal(shape=(1, 1, 100, 256)) model = test_model_tf(x.shape) print("FLOPs of tensorflow 2.3.1:", get_flops(model))
3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):
import thop x = torch.randn(1, 1, 100, 256) model = test_model_torch() flops, _ = thop.profile(model, inputs=(x,)) print("FLOPs of pytorch 1.10.1:", flops * 2)
需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。
3.4. 结果对比
三者计算出的 FLOPs 分别为:
tensorflow 1.12.0:
tensorflow 2.3.1:
pytorch 1.10.1:
可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。如读者知道其中详情,还请不吝赐教。
4. 总结
本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflow 与 pytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs。
到此这篇关于计算 tensorflow 和 pytorch 模型的浮点运算数的文章就介绍到这了,更多相关tensorflow 和 pytorch浮点运算数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
X 关闭
X 关闭
- 15G资费不大降!三大运营商谁提供的5G网速最快?中国信通院给出答案
- 2联想拯救者Y70发布最新预告:售价2970元起 迄今最便宜的骁龙8+旗舰
- 3亚马逊开始大规模推广掌纹支付技术 顾客可使用“挥手付”结账
- 4现代和起亚上半年出口20万辆新能源汽车同比增长30.6%
- 5如何让居民5分钟使用到各种设施?沙特“线性城市”来了
- 6AMD实现连续8个季度的增长 季度营收首次突破60亿美元利润更是翻倍
- 7转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 8充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 9好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 10名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?