Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Improve the supernode.py script
[simgrid.git] / examples / platforms / supernode.py
1 #! /usr/bin/env python3
2
3 # Copyright (c) 2006-2022. The SimGrid Team. All rights reserved.
4
5 # This program is free software; you can redistribute it and/or modify it
6 # under the terms of the license (GNU LGPL) which comes with this package.
7
8
9 # This script takes as input a C++ platform file, compiles it, then dumps the
10 # routing graph as a CSV and generates an image.
11 # The layout should be alright for any platform file, but the colors are very
12 # ad-hoc for file supernode.cpp : do not hesitate to adapt this script to your needs.
13 # An option is provided to "simplify" the graph by removing the link vertices. It assumes that these vertices have
14 # "link" in their name.
15
16 import sys
17 import subprocess
18 import pandas
19 import matplotlib as mpl
20 import matplotlib.pyplot as plt
21 import networkx as nx
22 import argparse
23 import tempfile
24 import os
25
26 try:
27     from palettable.colorbrewer.qualitative import Set1_9
28     colors = Set1_9.hex_colors
29 except ImportError:
30     print('Warning: could not import palettable, will use a default palette.')
31     colors = [None]*10
32
33
34 def run_command(cmd):
35     print(cmd)
36     proc = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
37     stdout, stderr = proc.communicate()
38     if proc.returncode != 0:
39         sys.exit(f'Command failed:\n{stderr.decode()}')
40
41
42 def compile_platform(platform_cpp, platform_so):
43     cmd = f'g++ -g -fPIC -shared -o {platform_so} {platform_cpp} -lsimgrid'
44     run_command(cmd)
45
46
47 def dump_csv(platform_so, platform_csv):
48     cmd = f'graphicator {platform_so} {platform_csv}'
49     run_command(cmd)
50
51
52 def merge_updown(graph):
53     '''
54     Merge all the UP and DOWN links.
55     '''
56     H = graph.copy()
57     downlinks = [v for v in graph if 'DOWN' in v]
58     mapping = {}
59     for down in downlinks:
60         up = down.replace('DOWN', 'UP')
61         H = nx.contracted_nodes(H, down, up)
62         mapping[down] = down.replace('_DOWN', '')
63     return nx.relabel_nodes(H, mapping)
64
65
66 def contract_links(graph):
67     '''
68     Remove all the 'link' vertices from the graph to directly connect the nodes.
69     Note: it assumes that link vertices have the "link" string in their name.
70     '''
71     H = graph.copy()
72     links = [v for v in graph if 'link' in v]
73     new_edges = []
74     for v in links:
75         neigh = [u for u in graph.neighbors(v) if 'link' not in u]  # with Floyd zones, we have links connected to links
76         assert len(neigh) == 2
77         new_edges.append(neigh)
78     # Adding edges from graph that have no links
79     for u, v in graph.edges:
80         if 'link' not in u and 'link' not in v:
81             new_edges.append((u, v))
82     return nx.from_edgelist(new_edges)
83
84
85 def load_graph(platform_csv, simplify_graph):
86     edges = pandas.read_csv(platform_csv)
87     graph = nx.Graph()
88     graph.add_edges_from([e for _, e in edges.drop_duplicates().iterrows()])
89     print(f'Loaded a graph with {len(graph)} vertices with {len(graph.edges)} edges')
90     if simplify_graph:
91         graph = contract_links(merge_updown(graph))
92         print(f'Simplified the graph, it now has {len(graph)} vertices with {len(graph.edges)} edges')
93     return graph
94
95
96 def plot_graph(graph, label=False, groups=[]):
97     # First, we compute the graph layout, i.e. the position of the nodes.
98     # The neato algorithm from graphviz is nicer, but this is an extra-dependency.
99     # The spring_layout is also not too bad.
100     try:
101         pos = nx.nx_agraph.graphviz_layout(graph, 'neato')
102     except ImportError:
103         print('Warning: could not import pygraphviz, will use another layout algorithm.')
104         pos = nx.spring_layout(graph, k=0.5, iterations=1000, seed=42)
105     plt.figure(figsize=(20, 15))
106     plt.axis('off')
107     all_nodes = set(graph)
108     # We then iterate on all the specified groups, to plot each of them in the right color.
109     # Note that the order of the groups is important here, as we are looking at substrings in the node names.
110     for i, grp in enumerate(groups):
111         nodes = {u for u in all_nodes if grp in u}
112         all_nodes -= nodes
113         nx.draw_networkx_nodes(graph, pos, nodelist=nodes, node_size=50, node_color=colors[i], label=grp.replace('_', ''))
114     nx.draw_networkx_nodes(graph, pos, nodelist=all_nodes, node_size=50, node_color=colors[-1], label='unknown')
115     # Finally we draw the edges, the (optional) labels, and the legend.
116     nx.draw_networkx_edges(graph, pos, alpha=0.3)
117     if label:
118         nx.draw_networkx_labels(graph, pos)
119     plt.legend(scatterpoints = 1)
120
121
122 def generate_svg(platform_csv, output_file, simplify_graph):
123     graph = load_graph(platform_csv, simplify_graph)
124     plot_graph(graph, label=False, groups=['router', 'link', 'cpu', '_node', 'supernode', 'cluster'])
125     plt.savefig(output_file)
126     print(f'Generated file {output_file}')
127
128
129 if __name__ == '__main__':
130     parser = argparse.ArgumentParser(description='Visualization of topologies for SimGrid C++ platforms')
131     parser.add_argument('input', type=str, help='SimGrid C++ platform file name (input)')
132     parser.add_argument('output', type=str, help='File name of the output image')
133     parser.add_argument('--simplify', action='store_true', help='Simplify the topology (removing link vertices)')
134     args = parser.parse_args()
135     if not args.input.endswith('.cpp'):
136         parser.error(f'SimGrid platform must be a C++ file (with .cpp extension), got {args.input}')
137     if not os.path.isfile(args.input):
138         parser.error(f'File {args.input} not found')
139     output_dir = os.path.dirname(args.output)
140     if output_dir != '' and not os.path.isdir(output_dir):
141         parser.error(f'Not a directory: {output_dir}')
142     with tempfile.TemporaryDirectory() as tmpdirname:
143         platform_cpp = args.input
144         platform_csv = os.path.join(tmpdirname, 'platform.csv')
145         platform_so = os.path.join(tmpdirname, 'platform.so')
146         compile_platform(platform_cpp, platform_so)
147         dump_csv(platform_so, platform_csv)
148         generate_svg(platform_csv, args.output, args.simplify)