Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Cosmetics: removal of a useless comment I wrote in the previous commits
[simgrid.git] / contrib / network_model / regression2.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2011, 2014. The SimGrid Team.
4 # All rights reserved.
5
6 # This program is free software; you can redistribute it and/or modify it
7 # under the terms of the license (GNU LGPL) which comes with this package.
8
9 # This script takes the following command line parameters
10 # 1) an input file containing 2 columns: message size and 1-way trip time
11 # 2) the maximum relative error for a line segment
12 # 3) the minimum number of points needed to justify adding a line segment
13 # 4) the number of links
14 # 5) the latency
15 # 6) the bandwidth
16
17 import sys
18
19 def compute_regression(points):
20     N = len(points)
21
22     if N < 1:
23         return None
24
25     if N < 2:
26         return (0, points[0][1])
27
28     Sx = Sy = Sxx = Syy = Sxy = 0.0
29
30     for x, y in points:
31         Sx  += x
32         Sy  += y
33         Sxx += x*x
34         Syy += y*y
35         Sxy += x*y
36     denom = Sxx * N - Sx * Sx
37     # don't return 0 or negative values as a matter of principle...
38     m = max(sys.float_info.min, (Sxy * N - Sy * Sx) / denom)
39     b = max(sys.float_info.min, (Sxx * Sy - Sx * Sxy) / denom)
40     return (m, b)
41
42 def compute_error(m, b, x, y):
43     yp = m*x+b
44     return abs(yp - y) / max(min(yp, y), sys.float_info.min)
45
46 def compute_max_error(m, b, points):
47     max_error = 0.0
48     for x, y in points:
49         max_error = max(max_error, compute_error(m, b, x, y))
50     return max_error
51
52 def get_max_error_point(m, b, points):
53     max_error_index = -1
54     max_error = 0.0
55
56     i = 0
57     while i < len(points):
58         x, y = points[i]
59         error = compute_error(m, b, x, y)
60         if error > max_error:
61             max_error_index = i
62             max_error = error
63         i += 1
64
65     return (max_error_index, max_error)
66
67 infile_name = sys.argv[1]
68 error_bound = float(sys.argv[2])
69 min_seg_points = int(sys.argv[3])
70 links = int(sys.argv[4])
71 latency = float(sys.argv[5])
72 bandwidth = float(sys.argv[6])
73
74 infile = open(infile_name, 'r')
75
76 # read datafile
77 points = []
78 for line in infile:
79     fields = line.split()
80     points.append((int(fields[0]), int(fields[1])))
81 infile.close()
82
83 # should sort points by x values
84 points.sort()
85
86 # break points up into segments
87 pointsets = []
88 lbi = 0
89 while lbi < len(points):
90     min_ubi = lbi
91     max_ubi = len(points) - 1
92     while max_ubi - min_ubi > 1:
93         ubi = (min_ubi + max_ubi) / 2
94         m, b = compute_regression(points[lbi:ubi+1])
95         max_error = compute_max_error(m, b, points[lbi:ubi+1])
96         if max_error > error_bound:
97             max_ubi = ubi - 1
98         else:
99             min_ubi = ubi
100     ubi = max_ubi
101     if min_ubi < max_ubi:
102         m, b = compute_regression(points[lbi:max_ubi+1])
103         max_error = compute_max_error(m, b, points[lbi:max_ubi+1])
104         if max_error > error_bound:
105             ubi = min_ubi
106     pointsets.append(points[lbi:ubi+1])
107     lbi = ubi+1
108
109 # try to merge larger segments if possible and compute piecewise regression
110 i = 0
111 segments = []
112 notoutliers = 0
113 while i < len(pointsets):
114     currpointset = []
115     j = i
116     while j < len(pointsets):
117         newpointset = currpointset + pointsets[j] 
118         # if joining a small segment, we can delete bad points
119         if len(pointsets[j]) < min_seg_points:
120             k = 0
121             while k < len(pointsets[j]):
122                 m, b = compute_regression(newpointset)
123                 max_error_index, max_error = get_max_error_point(m, b, newpointset)
124                 if max_error <= error_bound:
125                     break
126                 del newpointset[max_error_index]
127                 k += 1
128             # only add new pointset if we had to delete fewer than its length
129             # points
130             if k < len(pointsets[j]):
131                 i = j
132                 currpointset = newpointset   
133         # otherwise, we just see if it works...
134         else:
135             m, b = compute_regression(newpointset)
136             max_error = compute_max_error(m, b, newpointset)
137             if max_error > error_bound:
138                 break
139             i = j
140             currpointset = newpointset   
141         j += 1
142     i += 1
143     # outliers are ignored when constructing the piecewise funciton
144     if len(currpointset) < min_seg_points:
145         continue
146     notoutliers += len(currpointset)
147     m, b = compute_regression(currpointset)
148     lb = min(x for x, y in currpointset)
149     lat_factor = b / (1.0e6 * links * latency)
150     bw_factor = 1.0e6 / (m * bandwidth)
151     segments.append((lb, m, b, lat_factor, bw_factor))
152
153 outliers = len(points) - notoutliers
154 segments.sort()
155 segments.reverse()
156
157 print "/**--------- <copy/paste C code snippet in surf/network.c> -------------"
158 print "  * produced by:"
159 print "  *", " ".join(sys.argv)
160 print "  * outliers:", outliers
161 print "  * gnuplot: "
162 print "    plot \"%s\" using 1:2 with lines title \"data\", \\" % (infile_name)
163 for lb, m, b, lat_factor, bw_factor in segments:
164     print "        (x >= %d) ? %g*x+%g : \\" % (lb, m, b)
165 print "        1.0 with lines title \"piecewise function\""
166 print "  *-------------------------------------------------------------------*/"
167 print
168 print "static double smpi_bandwidth_factor(double size)\n{\n"
169 for lb, m, b, lat_factor, bw_factor in segments:
170     print "    if (size >= %d) return %g;" % (lb, bw_factor)
171 print "    return 1.0;\n}\n"
172 print "static double smpi_latency_factor(double size)\n{\n"
173 for lb, m, b, lat_factor, bw_factor in segments:
174     print "    if (size >= %d) return %g;" % (lb, lat_factor)
175 print "    return 1.0;\n}\n"
176 print "/**--------- <copy/paste C code snippet in surf/network.c> -----------*/"