From 1a3845b9e7a9accc2fce8902d5b27efd40b71fda Mon Sep 17 00:00:00 2001 From: sangyunxin Date: Wed, 17 Jul 2019 10:29:06 +0800 Subject: [PATCH 1/3] add `summary_depth` --- torchsummary/__init__.py | 2 +- torchsummary/torchsummary.py | 126 ++++++++++++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 2 deletions(-) diff --git a/torchsummary/__init__.py b/torchsummary/__init__.py index fa52a70..8542ecf 100644 --- a/torchsummary/__init__.py +++ b/torchsummary/__init__.py @@ -1 +1 @@ -from .torchsummary import summary \ No newline at end of file +from .torchsummary import summary, summary_depth \ No newline at end of file diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..66b04f0 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -5,6 +5,130 @@ from collections import OrderedDict import numpy as np +def output(summary, keys, left, right, depth=1): + nl = left - 1 + + for i in range(left, right): + layer = keys[i] + if summary[layer]['depth'] == depth: + if depth == 1: + start = "├─" + layer + else: + start = "| " * (depth - 1) + "└─" + layer + + new_line = "{:<40} {:<25} {:<15}".format( + start, + str(summary[layer]["output_shape"]), + "{0:,}".format(summary[layer]["nb_params"]) + ) + print(new_line) + + output(summary, keys, nl+1, i, depth+1) + nl = i + +def apply(model, fn, depth=0): + fn(model, depth) + for module in model.children(): + apply(module, fn, depth+1) + +def summary_depth(model, input_size, batch_size=-1, device="cuda", output_depth=3): + + def register_hook(module, depth): + + def hook(module, input, output): + class_name = str(module.__class__).split(".")[-1].split("'")[0] + idx[depth] = idx.get(depth, 0) + 1 + m_key = "%s: %i-%i" % (class_name, depth, idx[depth]) + summary[m_key] = OrderedDict() + summary[m_key]["input_shape"] = list(input[0].size()) + summary[m_key]["input_shape"][0] = batch_size + if isinstance(output, (list, tuple)): + summary[m_key]["output_shape"] = [ + [-1] + list(o.size())[1:] for o in output + ] + else: + summary[m_key]["output_shape"] = list(output.size()) + summary[m_key]["output_shape"][0] = batch_size + + params = 0 + if hasattr(module, "weight") and hasattr(module.weight, "size"): + params += torch.prod(torch.LongTensor(list(module.weight.size()))) + summary[m_key]["trainable"] = module.weight.requires_grad + if hasattr(module, "bias") and hasattr(module.bias, "size"): + params += torch.prod(torch.LongTensor(list(module.bias.size()))) + summary[m_key]["nb_params"] = params + summary[m_key]['depth'] = depth + + if module != model: + hooks.append(module.register_forward_hook(hook)) + + device = device.lower() + assert device in [ + "cuda", + "cpu", + ], "Input device is not valid, please specify 'cuda' or 'cpu'" + + if device == "cuda" and torch.cuda.is_available(): + dtype = torch.cuda.FloatTensor + else: + dtype = torch.FloatTensor + + # multiple inputs to the network + if isinstance(input_size, tuple): + input_size = [input_size] + + # batch_size of 2 for batchnorm + x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] + # print(type(x[0])) + + # create properties + summary = OrderedDict() + idx = {} + + hooks = [] + + # register hook + apply(model, register_hook) + + # make a forward pass + # print(x.shape) + model(*x) + + # remove these hooks + for h in hooks: + h.remove() + + keys = list(summary.keys()) + print("----------------------------------------------------------------------------") + line_new = "{:<40} {:<25} {:<15}".format("Layer (type:depth-idx)", "Output Shape", "Param #") + print(line_new) + print("============================================================================") + output(summary, keys, 0, len(keys)) + + total_params = 0 + total_output = 0 + trainable_params = 0 + for layer in summary: + total_params += summary[layer]["nb_params"] + total_output += np.prod(summary[layer]["output_shape"]) + if "trainable" in summary[layer]: + if summary[layer]["trainable"] == True: + trainable_params += summary[layer]["nb_params"] + + # assume 4 bytes/number (float on cuda). + total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) + total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) + + print("============================================================================") + print("Total params: {0:,}".format(total_params)) + print("Trainable params: {0:,}".format(trainable_params)) + print("Non-trainable params: {0:,}".format(total_params - trainable_params)) + print("----------------------------------------------------------------------------") + print("Input size (MB): %0.2f" % total_input_size) + print("Params size (MB): %0.2f" % total_params_size) + print("----------------------------------------------------------------------------") + # return summary + def summary(model, input_size, batch_size=-1, device="cuda"): @@ -112,4 +236,4 @@ def hook(module, input, output): print("Params size (MB): %0.2f" % total_params_size) print("Estimated Total Size (MB): %0.2f" % total_size) print("----------------------------------------------------------------") - # return summary + # return summary \ No newline at end of file From e3b29ce7e18fccd0776ad3658e6805da749fd51c Mon Sep 17 00:00:00 2001 From: sangyunxin Date: Wed, 17 Jul 2019 10:52:13 +0800 Subject: [PATCH 2/3] add `summary_depth` --- torchsummary/torchsummary.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 66b04f0..f080442 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -5,7 +5,10 @@ from collections import OrderedDict import numpy as np -def output(summary, keys, left, right, depth=1): +def output(summary, keys, left, right, output_depth, depth=1): + if depth > output_depth: + return + nl = left - 1 for i in range(left, right): @@ -23,7 +26,7 @@ def output(summary, keys, left, right, depth=1): ) print(new_line) - output(summary, keys, nl+1, i, depth+1) + output(summary, keys, nl+1, i, output_depth, depth+1) nl = i def apply(model, fn, depth=0): @@ -99,11 +102,11 @@ def hook(module, input, output): h.remove() keys = list(summary.keys()) - print("----------------------------------------------------------------------------") + print("-" * 90) line_new = "{:<40} {:<25} {:<15}".format("Layer (type:depth-idx)", "Output Shape", "Param #") print(line_new) - print("============================================================================") - output(summary, keys, 0, len(keys)) + print("=" * 90) + output(summary, keys, 0, len(keys), output_depth) total_params = 0 total_output = 0 @@ -119,14 +122,14 @@ def hook(module, input, output): total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) - print("============================================================================") + print("=" * 90) print("Total params: {0:,}".format(total_params)) print("Trainable params: {0:,}".format(trainable_params)) print("Non-trainable params: {0:,}".format(total_params - trainable_params)) - print("----------------------------------------------------------------------------") + print("-" * 90) print("Input size (MB): %0.2f" % total_input_size) print("Params size (MB): %0.2f" % total_params_size) - print("----------------------------------------------------------------------------") + print("-" * 90) # return summary From 80af31afc3376b496090410ea1b2d4c8b2618831 Mon Sep 17 00:00:00 2001 From: sangyunxin Date: Wed, 17 Jul 2019 11:16:36 +0800 Subject: [PATCH 3/3] display '--' when param equals 0 --- torchsummary/torchsummary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index f080442..c8cd3ec 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -22,7 +22,7 @@ def output(summary, keys, left, right, output_depth, depth=1): new_line = "{:<40} {:<25} {:<15}".format( start, str(summary[layer]["output_shape"]), - "{0:,}".format(summary[layer]["nb_params"]) + "--" if summary[layer]["nb_params"] == 0 else "{0:,}".format(summary[layer]["nb_params"]) ) print(new_line)