Skip to content

Commit 0591db6

Browse files
committed
fit: export Func1D.Hessian
Signed-off-by: Sebastien Binet <[email protected]>
1 parent 02e76ba commit 0591db6

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

fit/curve1d_example_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package fit_test
66

77
import (
8+
"fmt"
89
"image/color"
910
"log"
1011
"math"
@@ -13,7 +14,10 @@ import (
1314
"go-hep.org/x/hep/hbook"
1415
"go-hep.org/x/hep/hplot"
1516
"gonum.org/v1/gonum/floats"
17+
"gonum.org/v1/gonum/mat"
1618
"gonum.org/v1/gonum/optimize"
19+
"gonum.org/v1/gonum/stat"
20+
"gonum.org/v1/gonum/stat/distuv"
1721
"gonum.org/v1/plot/plotter"
1822
"gonum.org/v1/plot/vg"
1923
)
@@ -289,3 +293,130 @@ func ExampleCurve1D_powerlaw() {
289293
}
290294
}
291295
}
296+
297+
func ExampleCurve1D_hessian() {
298+
var (
299+
cst = 3.0
300+
mean = 30.0
301+
sigma = 20.0
302+
want = []float64{cst, mean, sigma}
303+
)
304+
305+
xdata, ydata, err := readXY("testdata/gauss-data.txt")
306+
if err != nil {
307+
log.Fatal(err)
308+
}
309+
310+
// use a small sample
311+
xdata = xdata[:min(25, len(xdata))]
312+
ydata = ydata[:min(25, len(ydata))]
313+
314+
gauss := func(x, cst, mu, sigma float64) float64 {
315+
v := (x - mu)
316+
return cst * math.Exp(-v*v/sigma)
317+
}
318+
319+
f1d := fit.Func1D{
320+
F: func(x float64, ps []float64) float64 {
321+
return gauss(x, ps[0], ps[1], ps[2])
322+
},
323+
X: xdata,
324+
Y: ydata,
325+
Ps: []float64{10, 10, 10},
326+
}
327+
res, err := fit.Curve1D(f1d, nil, &optimize.NelderMead{})
328+
if err != nil {
329+
log.Fatal(err)
330+
}
331+
332+
if err := res.Status.Err(); err != nil {
333+
log.Fatal(err)
334+
}
335+
if got := res.X; !floats.EqualApprox(got, want, 1e-3) {
336+
log.Fatalf("got= %v\nwant=%v\n", got, want)
337+
}
338+
339+
inv := mat.NewSymDense(len(res.Location.X), nil)
340+
f1d.Hessian(inv, res.Location.X)
341+
// fmt.Printf("hessian: %1.2e\n", mat.Formatted(inv, mat.Prefix(" ")))
342+
343+
popt := res.Location.X
344+
pcov := mat.NewDense(len(popt), len(popt), nil)
345+
{
346+
var chol mat.Cholesky
347+
if ok := chol.Factorize(inv); !ok {
348+
log.Fatalf("cov-matrix not positive semi-definite")
349+
}
350+
351+
err := chol.InverseTo(inv)
352+
if err != nil {
353+
log.Fatalf("could not inverse matrix: %+v", err)
354+
}
355+
pcov.Copy(inv)
356+
}
357+
358+
// compute goodness-of-fit.
359+
gof := newGoF(f1d.X, f1d.Y, popt, func(x float64) float64 {
360+
return f1d.F(x, popt)
361+
})
362+
363+
pcov.Scale(gof.SSE/float64(len(f1d.X)-len(popt)), pcov)
364+
365+
// fmt.Printf("pcov: %1.2e\n", mat.Formatted(pcov, mat.Prefix(" ")))
366+
367+
var (
368+
n = float64(len(f1d.X)) // number of data points
369+
ndf = n - float64(len(popt)) // number of degrees of freedom
370+
t = distuv.StudentsT{
371+
Mu: 0,
372+
Sigma: 1,
373+
Nu: ndf,
374+
}.Quantile(0.5 * (1 + 0.95))
375+
)
376+
377+
for i, p := range popt {
378+
sigma := math.Sqrt(pcov.At(i, i))
379+
fmt.Printf("c%d: %1.5e [%1.5e, %1.5e] -- truth: %g\n", i, p, p-sigma*t, p+sigma*t, want[i])
380+
}
381+
// Output:
382+
//c0: 2.99999e+00 [2.99999e+00, 3.00000e+00] -- truth: 3
383+
//c1: 3.00000e+01 [3.00000e+01, 3.00000e+01] -- truth: 30
384+
//c2: 2.00000e+01 [2.00000e+01, 2.00000e+01] -- truth: 20
385+
}
386+
387+
type GoF struct {
388+
SSE float64 // Sum of squares due to error
389+
Rsquare float64 // R-Square is the square of the correlation between the response values and the predicted response values
390+
NdF int // Number of degrees of freedom
391+
AdjRsquare float64 // Degrees of freedom adjusted R-Square
392+
RMSE float64 // Root mean squared error
393+
}
394+
395+
func newGoF(xs, ys, ps []float64, f func(float64) float64) GoF {
396+
switch {
397+
case len(xs) != len(ys):
398+
panic("invalid lengths")
399+
}
400+
401+
var gof GoF
402+
403+
var (
404+
ye = make([]float64, len(ys))
405+
nn = float64(len(xs) - 1)
406+
vv = float64(len(xs) - len(ps))
407+
)
408+
409+
for i, x := range xs {
410+
ye[i] = f(x)
411+
dy := ys[i] - ye[i]
412+
gof.SSE += dy * dy
413+
gof.RMSE += dy * dy
414+
}
415+
416+
gof.Rsquare = stat.RSquaredFrom(ye, ys, nil)
417+
gof.AdjRsquare = 1 - ((1 - gof.Rsquare) * nn / vv)
418+
gof.RMSE = math.Sqrt(gof.RMSE / float64(len(ys)-len(ps)))
419+
gof.NdF = len(ys) - len(ps)
420+
421+
return gof
422+
}

fit/fit.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ func (f *Func1D) init() {
8787
}
8888
}
8989

90+
// Hessian computes the hessian matrix at the provided x point.
91+
func (f *Func1D) Hessian(hess *mat.SymDense, x []float64) {
92+
if f.hess == nil {
93+
f.init()
94+
}
95+
f.hess(hess, x)
96+
}
97+
9098
// FuncND describes a multivariate function F(x0, x1... xn; p0, p1... pn)
9199
// for which the parameters ps can be found with a fit.
92100
type FuncND struct {

0 commit comments

Comments
 (0)