1
1
2
- local ffi = require ' ffi'
3
- local bit = require ' bit'
4
- local gsl = require ' gsl'
2
+ if jit .arch ~= ' x64' then
3
+ print (' WARNING: please use BIT=64 for optimal OpenBLAS performance' )
4
+ end
5
+
6
+ local ffi = require ' ffi'
7
+ local bit = require ' bit'
8
+ local time = require ' time'
9
+ local alg = require ' sci.alg'
10
+ local prng = require ' sci.prng'
11
+ local stat = require ' sci.stat'
12
+ local dist = require ' sci.dist'
13
+ local complex = require ' sci.complex'
5
14
6
- local min , max , abs , sqrt , random , floor = math.min , math.max , math.abs , math. sqrt , math.random , math.floor
15
+ local min , sqrt , random , abs = math.min , math.sqrt , math.random , math.abs
7
16
local cabs = complex .abs
8
17
local rshift = bit .rshift
9
18
local format = string.format
19
+ local nowutc = time .nowutc
20
+ local rng = prng .std ()
21
+ local vec , mat , trace , join = alg .vec , alg .mat , alg .trace , alg .join
22
+ local var , mean = stat .var , stat .mean
10
23
11
- local gettime
12
- do
13
- ffi .cdef [[
14
- struct timeval {
15
- long tv_sec ;
16
- long tv_usec ;
17
- };
18
-
19
- int gettimeofday (struct timeval * tp , void * tzp );
20
- ]]
21
-
22
- local tv = ffi .new (' struct timeval[1]' )
23
-
24
- gettime = function ()
25
- ffi .C .gettimeofday (tv , nil )
26
- return tv [0 ].tv_sec , tv [0 ].tv_usec
27
- end
28
- end
29
-
30
- -- return the elapsed time in ms
24
+ ---- ----------------------------------------------------------------------------
31
25
local function elapsed (f )
32
- local s0 , us0 = gettime ()
33
- f ()
34
- local s1 , us1 = gettime ()
35
- return tonumber (s1 - s0 ) * 1000 + tonumber (us1 - us0 ) / 1000
36
- end
37
-
38
- local function timeit (f , name )
39
- local t = nil
40
- for k = 1 , 5 do
41
- local tx = elapsed (f )
42
- t = t and min (t , tx ) or tx
26
+ local t0 = nowutc ()
27
+ local val1 , val2 = f ()
28
+ local t1 = nowutc ()
29
+ return (t1 - t0 ):tomilliseconds (), val1 , val2
30
+ end
31
+
32
+ local function timeit (f , name , check )
33
+ local t , k , s = 1 / 0 , 0 , nowutc ()
34
+ while true do
35
+ k = k + 1
36
+ local tx , val1 , val2 = elapsed (f )
37
+ t = min (t , tx )
38
+ if check then
39
+ check (val1 , val2 )
40
+ end
41
+ if k > 5 and (nowutc () - s ):toseconds () >= 2 then break end
43
42
end
44
- print (format (" lua,%s,%g" , name , t ))
43
+ io.write (format (' lua,%s,%g\n ' , name , t ))
45
44
end
46
45
46
+ ---- ----------------------------------------------------------------------------
47
47
local function fib (n )
48
48
if n < 2 then
49
49
return n
@@ -52,23 +52,21 @@ local function fib(n)
52
52
end
53
53
end
54
54
55
- assert (fib (20 ) == 6765 )
56
- timeit (|| fib (20 ), " fib" )
55
+ timeit (function () return fib (20 ) end , ' fib' , function (x ) assert (x == 6765 ) end )
57
56
58
57
local function parseint ()
59
- local r = rng .new (' rand' )
60
58
local lmt = 2 ^ 32 - 1
61
59
local n , m
62
60
for i = 1 , 1000 do
63
- n = r : getint (lmt )
61
+ n = random (lmt ) -- Between 0 and 2^32 - 1, i.e. uint32_t.
64
62
local s = format (' 0x%x' , tonumber (n ))
65
63
m = tonumber (s )
66
64
end
67
- assert (m == n )
68
- return n
65
+ assert (n == m ) -- Done here to be even with Julia benchmark.
66
+ return n , m
69
67
end
70
68
71
- timeit (parseint , " parse_int" )
69
+ timeit (parseint , ' parse_int' )
72
70
73
71
local function mandel (z )
74
72
local c = z
@@ -81,28 +79,23 @@ local function mandel(z)
81
79
end
82
80
return maxiter
83
81
end
84
- function mandelperf ()
85
- local a , re , im , z
86
- a = ffi .new (" double[?]" , 546 )
87
- r = 0
82
+ local function mandelperf ()
83
+ local a = ffi .new (' double[?]' , 546 )
88
84
for r = - 20 , 5 do
89
- re = r * 0.1
85
+ local re = r * 0.1
90
86
for i =- 10 , 10 do
91
- im = i * 0.1
87
+ local im = i * 0.1
92
88
a [r * 21 + i + 430 ] = mandel (re + 1 i * im )
93
89
end
94
90
end
95
91
return a
96
92
end
97
93
98
- do
99
- local a = mandelperf ()
94
+ timeit (mandelperf , ' mandel' , function (a )
100
95
local sum = 0
101
96
for i = 0 , 545 do sum = sum + a [i ] end
102
97
assert (sum == 14791 )
103
- end
104
-
105
- timeit (mandelperf , " mandel" )
98
+ end )
106
99
107
100
local function qsort (a , lo , hi )
108
101
local i , j = lo , hi
@@ -124,12 +117,19 @@ end
124
117
125
118
local function sortperf ()
126
119
local n = 5000
127
- local r = rng .new (' rand' )
128
- local v = iter .ilist (|| r :get (), n )
129
- qsort (v , 1 , n )
120
+ local v = ffi .new (' double[?]' , n + 1 )
121
+ for i = 1 ,n do
122
+ v [i ] = rng :sample ()
123
+ end
124
+ return qsort (v , 1 , n )
130
125
end
131
126
132
-
127
+ timeit (sortperf , ' quicksort' , function (x )
128
+ for i = 2 ,5000 do
129
+ assert (x [i - 1 ] <= x [i ])
130
+ end
131
+ end
132
+ )
133
133
134
134
local function pisum ()
135
135
local sum
@@ -142,104 +142,59 @@ local function pisum()
142
142
return sum
143
143
end
144
144
145
- local function stat (v )
146
- local p , q = 0 , 0
147
- local n = # v
148
- for k = 1 , n do
149
- local x = v [k ]
150
- p = p + x
151
- q = q + x * x
152
- end
153
- return sqrt ((n * (n * q - p * p ))/ ((n - 1 )* p * p ))
154
- end
155
-
156
- local function randmatstat (t )
157
- local n = 5
158
- local A = iter .ilist (|| matrix .alloc (n , n ), 4 )
159
-
160
- local P = matrix .alloc (n , 4 * n )
161
- local Q = matrix .alloc (2 * n , 2 * n )
162
-
163
- local PtP1 = matrix .alloc (4 * n , 4 * n )
164
- local PtP2 = matrix .alloc (4 * n , 4 * n )
165
- local QtQ1 = matrix .alloc (2 * n , 2 * n )
166
- local QtQ2 = matrix .alloc (2 * n , 2 * n )
167
-
168
- local get , set = A [1 ].get , A [1 ].set
145
+ timeit (pisum , ' pi_sum' , function (x )
146
+ assert (abs (x - 1.644834071848065 ) < 1e-12 )
147
+ end )
169
148
170
- local r = rng .new (' rand' )
171
- local randn = || rnd .gaussian (r , 1 )
172
-
173
- local function hstackf (i , j )
174
- local k , r = math .divmod (j - 1 , n )
175
- return get (A [k + 1 ], i , r + 1 )
149
+ local function rand (r , c )
150
+ local x = mat (r , c )
151
+ for i = 1 ,# x do
152
+ x [i ] = rng :sample ()
176
153
end
154
+ return x
155
+ end
177
156
178
- local function vstackf ( i , j )
179
- local ik , ir = math . divmod ( i - 1 , n )
180
- local jk , jr = math . divmod ( j - 1 , n )
181
- return get ( A [ 2 * ik + jk + 1 ], ir + 1 , jr + 1 )
157
+ local function randn ( r , c )
158
+ local x = mat ( r , c )
159
+ for i = 1 , # x do
160
+ x [ i ] = dist . normal ( 0 , 1 ): sample ( rng )
182
161
end
162
+ return x
163
+ end
183
164
184
- local Tr , NT = gsl .CblasTrans , gsl .CblasNoTrans
185
-
186
- local v , w = {}, {}
187
-
188
- for i = 1 , t do
189
- matrix .fset (A [1 ], randn )
190
- matrix .fset (A [2 ], randn )
191
- matrix .fset (A [3 ], randn )
192
- matrix .fset (A [4 ], randn )
193
-
194
- matrix .fset (P , hstackf )
195
- matrix .fset (Q , vstackf )
196
-
197
- gsl .gsl_blas_dgemm (Tr , NT , 1.0 , P , P , 0.0 , PtP1 )
198
- gsl .gsl_blas_dgemm (NT , NT , 1.0 , PtP1 , PtP1 , 0.0 , PtP2 )
199
- gsl .gsl_blas_dgemm (NT , NT , 1.0 , PtP2 , PtP2 , 0.0 , PtP1 )
200
-
201
- local vi = 0
202
- for j = 1 , n do vi = vi + get (PtP1 , j , j ) end
203
- v [i ] = vi
204
-
205
- gsl .gsl_blas_dgemm (Tr , NT , 1.0 , Q , Q , 0.0 , QtQ1 )
206
- gsl .gsl_blas_dgemm (NT , NT , 1.0 , QtQ1 , QtQ1 , 0.0 , QtQ2 )
207
- gsl .gsl_blas_dgemm (NT , NT , 1.0 , QtQ2 , QtQ2 , 0.0 , QtQ1 )
208
-
209
- local wi = 0
210
- for j = 1 , 2 * n do wi = wi + get (QtQ1 , j , j ) end
211
- w [i ] = wi
165
+ local function randmatstat (t )
166
+ local n = 5
167
+ local v , w = vec (t ), vec (t )
168
+ for i = 1 ,t do
169
+ local a , b , c , d = randn (n , n ), randn (n , n ), randn (n , n ), randn (n , n )
170
+ local P = join (a .. b .. c .. d )
171
+ local Q = join (a .. b , c .. d )
172
+ v [i ] = trace ((P []` **P[])^^4)
173
+ w [i ] = trace ((Q []` **Q[])^^4)
212
174
end
213
-
214
- return stat (v ), stat (w )
175
+ return sqrt (var (v ))/ mean (v ), sqrt (var (w ))/ mean (w )
215
176
end
216
177
217
- do
218
- local s1 , s2 = randmatstat (1000 )
219
- assert ( 0.5 < s1 and s1 < 1.0
220
- and 0.5 < s2 and s2 < 1.0 )
221
- end
178
+ timeit (function () return randmatstat (1000 ) end , ' rand_mat_stat' ,
179
+ function (s1 , s2 )
180
+ assert ( 0.5 < s1 and s1 < 1.0 and 0.5 < s2 and s2 < 1.0 )
181
+ end )
222
182
223
183
local function randmatmult (n )
224
- local r = rng .new (' rand' )
225
- -- local rand = || r:get()
226
- local rand = random
227
- local a = matrix .new (n , n , rand )
228
- local b = matrix .new (n , n , rand )
229
- return a * b
230
- end
231
-
232
- local function printfd (n )
233
- local f = io.open (" /dev/null" ," w" )
234
- for i = 1 , n do
235
- f :write (format (" %d %d\n " , i , i + 1 ))
236
- end
237
- f :close ()
184
+ local a , b = rand (n , n ), rand (n , n )
185
+ return a []** b []
238
186
end
239
187
188
+ timeit (function () return randmatmult (1000 ) end , ' rand_mat_mul' )
240
189
241
- timeit (sortperf , " quicksort" )
242
- timeit (pisum , " pi_sum" )
243
- timeit (|| randmatstat (1000 ), " rand_mat_stat" )
244
- timeit (|| randmatmult (1000 ), " rand_mat_mul" )
245
- -- timeit(|| printfd(100000), "printfd")
190
+ if jit .os ~= ' Windows' then
191
+ local function printfd (n )
192
+ local f = io.open (' /dev/null' ,' w' )
193
+ for i = 1 , n do
194
+ f :write (format (' %d %d\n ' , i , i + 1 ))
195
+ end
196
+ f :close ()
197
+ end
198
+
199
+ timeit (function () return printfd (100000 ) end , ' printfd' )
200
+ end
0 commit comments