# Reconstructed from a patch against torchsummary/torchsummary.py that adds
# ``summary_depth`` (plus its helpers ``output`` and ``apply``) above the
# existing ``summary`` function. The same patch changes
# torchsummary/__init__.py to:
#     from .torchsummary import summary, summary_depth
from collections import OrderedDict

import numpy as np
import torch


def output(summary, keys, left, right, output_depth, depth=1):
    """Print the rows of ``summary`` as an indented tree.

    Rows are expected in the order forward hooks record them: every module
    appears *after* the sub-modules it called. For each row at ``depth``
    found in ``keys[left:right]``, the row is printed and the rows between
    the previously printed sibling and this one are then printed
    recursively as its children.

    Args:
        summary: Mapping of layer key -> dict holding at least ``"depth"``,
            ``"output_shape"`` and ``"nb_params"``.
        keys: Layer keys of ``summary`` in recording order.
        left: First index (inclusive) of the window to scan.
        right: Last index (exclusive) of the window to scan.
        output_depth: Deepest level that is printed.
        depth: Level printed by this call; deeper levels are handled by
            recursion.
    """
    if depth > output_depth:
        return

    # Index of the last row printed at this depth; everything after it and
    # before the next row at this depth belongs to that next row.
    prev = left - 1

    for i in range(left, right):
        layer = keys[i]
        info = summary[layer]
        if info["depth"] != depth:
            continue

        if depth == 1:
            start = "├─" + layer
        else:
            start = "| " * (depth - 1) + "└─" + layer

        nb_params = info["nb_params"]
        new_line = "{:<40} {:<25} {:<15}".format(
            start,
            str(info["output_shape"]),
            "--" if nb_params == 0 else "{0:,}".format(nb_params),
        )
        print(new_line)

        output(summary, keys, prev + 1, i, output_depth, depth + 1)
        prev = i


def apply(model, fn, depth=0):
    """Call ``fn(module, depth)`` on ``model`` and, recursively, its children.

    Args:
        model: Object exposing ``children()``; visited first, at ``depth``.
        fn: Callable taking ``(module, depth)``.
        depth: Depth assigned to ``model``; each level of children gets
            one more.
    """
    fn(model, depth)
    for module in model.children():
        apply(module, fn, depth + 1)


def summary_depth(model, input_size, batch_size=-1, device="cuda", output_depth=3):
    """Print a depth-indented, per-layer summary of ``model``.

    A forward hook is registered on every sub-module (not on ``model``
    itself), one forward pass is run on random input, and the recorded
    output shapes and parameter counts are printed as a tree.

    Only ``weight`` and ``bias`` attributes are counted as parameters.
    NOTE(review): a sub-module that is never called itself (its children
    are invoked directly) records no row, so its children would be listed
    under the next printed row of the parent depth - confirm whether that
    matters for the models this is used with.

    Args:
        model: Module to summarise; called as ``model(*inputs)``.
        input_size: Shape of one input without the batch dimension (a
            tuple), or a list of such shapes for several inputs.
        batch_size: Value shown in the batch dimension of reported shapes
            and used for the input-size estimate.
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). CUDA tensors
            are only used when ``torch.cuda.is_available()``.
        output_depth: Deepest nesting level to print.

    Raises:
        AssertionError: If ``device`` is neither ``"cuda"`` nor ``"cpu"``.
    """

    def register_hook(module, depth):

        def hook(module, inputs, outputs):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            idx[depth] = idx.get(depth, 0) + 1
            m_key = "%s: %i-%i" % (class_name, depth, idx[depth])

            info = OrderedDict()
            info["input_shape"] = list(inputs[0].size())
            info["input_shape"][0] = batch_size
            if isinstance(outputs, (list, tuple)):
                # Use batch_size here as well, so multi-output layers agree
                # with single-output ones (identical for the default -1).
                info["output_shape"] = [
                    [batch_size] + list(o.size())[1:] for o in outputs
                ]
            else:
                info["output_shape"] = list(outputs.size())
                info["output_shape"][0] = batch_size

            # Keep counts as plain ints so the totals never depend on
            # whether any layer contributed a tensor value.
            params = 0
            weight = getattr(module, "weight", None)
            if hasattr(weight, "size"):
                params += weight.numel()
                info["trainable"] = weight.requires_grad
            bias = getattr(module, "bias", None)
            if hasattr(bias, "size"):
                params += bias.numel()
            info["nb_params"] = params
            info["depth"] = depth
            summary[m_key] = info

        # The root model is not listed as a row of its own.
        if module is not model:
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # create properties
    summary = OrderedDict()
    idx = {}
    hooks = []

    try:
        # register hooks, then make a forward pass
        apply(model, register_hook)
        model(*x)
    finally:
        # Always detach the hooks, even if the forward pass fails, so the
        # caller's model is left untouched.
        for h in hooks:
            h.remove()

    keys = list(summary.keys())
    print("-" * 90)
    line_new = "{:<40} {:<25} {:<15}".format(
        "Layer (type:depth-idx)", "Output Shape", "Param #"
    )
    print(line_new)
    print("=" * 90)
    output(summary, keys, 0, len(keys), output_depth)

    total_params = 0
    trainable_params = 0
    for info in summary.values():
        total_params += info["nb_params"]
        if info.get("trainable"):
            trainable_params += info["nb_params"]

    # assume 4 bytes/number (float on cuda).
    # Element counts of several inputs add up; they do not multiply.
    input_elements = sum(int(np.prod(in_size)) for in_size in input_size)
    total_input_size = abs(input_elements * batch_size * 4. / (1024 ** 2.))
    total_params_size = abs(total_params * 4. / (1024 ** 2.))

    print("=" * 90)
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("-" * 90)
    print("Input size (MB): %0.2f" % total_input_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("-" * 90)
    # return summary