Skip to content
This repository was archived by the owner on May 14, 2023. It is now read-only.

Commit 597d435

Browse files
committed
Add TestSumLogs
1 parent d026a44 commit 597d435

File tree

3 files changed

+73
-69
lines changed

3 files changed

+73
-69
lines changed

extract.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,53 @@ import (
1414
"gonum.org/v1/gonum/mat"
1515
)
1616

17+
var maxLogDiff = math.Log2(1e-30)
18+
19+
type mappedProbDist struct {
20+
dict map[string]float64
21+
log bool
22+
}
23+
24+
func (m *mappedProbDist) prob(label string) float64 {
25+
if p, found := m.dict[label]; found {
26+
return math.Pow(2, p)
27+
}
28+
return 0.0
29+
}
30+
31+
func (m *mappedProbDist) max() string {
32+
var class string
33+
max := math.Inf(-1)
34+
for label, value := range m.dict {
35+
if value > max {
36+
max = value
37+
class = label
38+
}
39+
}
40+
return class
41+
}
42+
43+
func newMappedProbDist(dict map[string]float64, normalize bool) *mappedProbDist {
44+
if normalize {
45+
values := []float64{}
46+
for _, v := range dict {
47+
values = append(values, v)
48+
}
49+
sum := sumLogs(values)
50+
if sum <= math.Inf(-1) {
51+
p := math.Log2(1.0 / float64(len(dict)))
52+
for k := range dict {
53+
dict[k] = p
54+
}
55+
} else {
56+
for k := range dict {
57+
dict[k] -= sum
58+
}
59+
}
60+
}
61+
return &mappedProbDist{dict: dict, log: true}
62+
}
63+
1764
type encodedValue struct {
1865
key int
1966
value int
@@ -480,3 +527,24 @@ func empiricalCount(corpus featureSet, encoding *binaryMaxentClassifier) *mat.Ve
480527
}
481528
return count
482529
}
530+
531+
func addLogs(x, y float64) float64 {
532+
if x < y+maxLogDiff {
533+
return y
534+
} else if y < x+maxLogDiff {
535+
return x
536+
}
537+
base := math.Min(x, y)
538+
return base + math.Log2(math.Pow(2, x-base)+math.Pow(2, y-base))
539+
}
540+
541+
func sumLogs(logs []float64) float64 {
542+
if len(logs) == 0 {
543+
return math.Inf(-1)
544+
}
545+
sum := logs[0]
546+
for _, log := range logs[1:] {
547+
sum = addLogs(sum, log)
548+
}
549+
return sum
550+
}

extract_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"io"
77
"io/ioutil"
8+
"math"
89
"path/filepath"
910
"reflect"
1011
"testing"
@@ -59,6 +60,10 @@ func split(data []prodigyOuput) ([]EntityContext, []prodigyOuput) {
5960
return train, test
6061
}
6162

63+
func TestSumLogs(t *testing.T) {
64+
assert.Equal(t, 3.0, sumLogs([]float64{math.Log2(3), math.Log2(5)}))
65+
}
66+
6267
func TestNERProdigy(t *testing.T) {
6368
data := filepath.Join(testdata, "reddit_product.jsonl")
6469

utilities.go

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,81 +3,12 @@ package prose
33
import (
44
"bytes"
55
"encoding/gob"
6-
"math"
76
"os"
87
"path"
98
"strconv"
109
"strings"
1110
)
1211

13-
var maxLogDiff = math.Log2(1e-30)
14-
15-
type mappedProbDist struct {
16-
dict map[string]float64
17-
log bool
18-
}
19-
20-
func (m *mappedProbDist) prob(label string) float64 {
21-
if p, found := m.dict[label]; found {
22-
return math.Pow(2, p)
23-
}
24-
return 0.0
25-
}
26-
27-
func (m *mappedProbDist) max() string {
28-
var class string
29-
max := math.Inf(-1)
30-
for label, value := range m.dict {
31-
if value > max {
32-
max = value
33-
class = label
34-
}
35-
}
36-
return class
37-
}
38-
39-
func newMappedProbDist(dict map[string]float64, normalize bool) *mappedProbDist {
40-
if normalize {
41-
values := []float64{}
42-
for _, v := range dict {
43-
values = append(values, v)
44-
}
45-
sum := sumLogs(values)
46-
if sum <= math.Inf(-1) {
47-
p := math.Log2(1.0 / float64(len(dict)))
48-
for k := range dict {
49-
dict[k] = p
50-
}
51-
} else {
52-
for k := range dict {
53-
dict[k] -= sum
54-
}
55-
}
56-
}
57-
return &mappedProbDist{dict: dict, log: true}
58-
}
59-
60-
func addLogs(x, y float64) float64 {
61-
if x < y+maxLogDiff {
62-
return y
63-
} else if y < x+maxLogDiff {
64-
return x
65-
}
66-
base := math.Min(x, y)
67-
return base + math.Log2(math.Pow(2, x-base)+math.Pow(2, y-base))
68-
}
69-
70-
func sumLogs(logs []float64) float64 {
71-
if len(logs) == 0 {
72-
return math.Inf(-1)
73-
}
74-
sum := logs[0]
75-
for _, log := range logs[1:] {
76-
sum = addLogs(sum, log)
77-
}
78-
return sum
79-
}
80-
8112
// checkError panics if `err` is not `nil`.
8213
func checkError(err error) {
8314
if err != nil {

0 commit comments

Comments
 (0)