Skip to content
This repository was archived by the owner on May 14, 2023. It is now read-only.

Commit a376476

Browse files
authored
Merge pull request #77 from itzamna314/master
Adds ModelFromFS
2 parents 6d09853 + ddaf0b7 commit a376476

File tree

4 files changed

+95
-19
lines changed

4 files changed

+95
-19
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/jdkato/prose/v3
22

3-
go 1.13
3+
go 1.16
44

55
require (
66
github.com/neurosnap/sentences v1.0.6 // indirect

model.go

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package prose
22

33
import (
4+
"io"
5+
"io/fs"
46
"os"
57
"path/filepath"
68
)
@@ -60,12 +62,45 @@ func ModelFromData(name string, sources ...DataSource) *Model {
6062

6163
// ModelFromDisk loads a Model from the user-provided location.
6264
func ModelFromDisk(path string) *Model {
63-
name, classifier := loadClassifier(path)
65+
filesys := os.DirFS(path)
66+
return &Model{
67+
Name: filepath.Base(path),
68+
69+
extracter: loadClassifier(filesys),
70+
tagger: newPerceptronTagger(),
71+
}
72+
}
73+
74+
// ModelFromFS loads the model called name from the first entry named name found while walking filesys.
75+
func ModelFromFS(name string, filesys fs.FS) *Model {
76+
// Locate a folder matching name within filesys
77+
var modelFS fs.FS
78+
err := fs.WalkDir(filesys, ".", func(path string, d fs.DirEntry, err error) error {
79+
if err != nil {
80+
return err
81+
}
82+
83+
// Model located. Exit tree traversal
84+
if d.Name() == name {
85+
modelFS, err = fs.Sub(filesys, path)
86+
if err != nil {
87+
return err
88+
}
89+
return io.EOF
90+
}
91+
92+
return nil
93+
})
94+
if err != io.EOF {
95+
checkError(err)
96+
}
97+
6498
return &Model{
6599
Name: name,
66100

67-
extracter: classifier,
68-
tagger: newPerceptronTagger()}
101+
extracter: loadClassifier(modelFS),
102+
tagger: newPerceptronTagger(),
103+
}
69104
}
70105

71106
// Write saves a Model to the user-provided location.
@@ -96,24 +131,28 @@ func loadTagger(path string) *perceptronTagger {
96131
return newTrainedPerceptronTagger(model)
97132
}*/
98133

99-
func loadClassifier(path string) (string, *entityExtracter) {
134+
func loadClassifier(filesys fs.FS) *entityExtracter {
100135
var mapping map[string]int
101136
var weights []float64
102137
var labels []string
103138

104-
loc := filepath.Join(path, "Maxent")
105-
dec := getDiskAsset(filepath.Join(loc, "mapping.gob"))
106-
checkError(dec.Decode(&mapping))
139+
maxent, err := fs.Sub(filesys, "Maxent")
140+
checkError(err)
141+
142+
file, err := maxent.Open("mapping.gob")
143+
checkError(err)
144+
checkError(getDiskAsset(file).Decode(&mapping))
107145

108-
dec = getDiskAsset(filepath.Join(loc, "weights.gob"))
109-
checkError(dec.Decode(&weights))
146+
file, err = maxent.Open("weights.gob")
147+
checkError(err)
148+
checkError(getDiskAsset(file).Decode(&weights))
110149

111-
dec = getDiskAsset(filepath.Join(loc, "labels.gob"))
112-
checkError(dec.Decode(&labels))
150+
file, err = maxent.Open("labels.gob")
151+
checkError(err)
152+
checkError(getDiskAsset(file).Decode(&labels))
113153

114154
model := newMaxentClassifier(weights, mapping, labels)
115-
name := filepath.Base(path)
116-
return name, newTrainedEntityExtracter(model)
155+
return newTrainedEntityExtracter(model)
117156
}
118157

119158
func defaultModel(tagging, classifying bool) *Model {

model_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package prose
22

33
import (
4+
"embed"
5+
"io/fs"
46
"os"
57
"path/filepath"
68
"testing"
@@ -26,3 +28,40 @@ func TestModelFromDisk(t *testing.T) {
2628
t.Errorf("ModelFromDisk() expected = temp, got = %v", model.Name)
2729
}
2830
}
31+
32+
//go:embed testdata/PRODUCT
33+
var embeddedModel embed.FS
34+
35+
func TestModelFromFS(t *testing.T) {
36+
err := fs.WalkDir(embeddedModel, ".", func(path string, d fs.DirEntry, err error) error {
37+
//fmt.Printf("Walking dir %s, err %s\n", path, err)
38+
return nil
39+
})
40+
41+
// Load the embedded PRODUCT model
42+
model := ModelFromFS("PRODUCT", embeddedModel)
43+
if model.Name != "PRODUCT" {
44+
t.Errorf("ModelFromFS() expected = PRODUCT, got = %v", model.Name)
45+
}
46+
47+
doc, err := NewDocument("Windows 10 is an operating system",
48+
UsingModel(model))
49+
50+
if err != nil {
51+
t.Errorf("Failed to create doc with ModelFromFS")
52+
}
53+
54+
ents := doc.Entities()
55+
56+
if len(ents) != 1 {
57+
t.Fatalf("Expected 1 entity, got %v", ents)
58+
}
59+
60+
if ents[0].Text != "Windows 10" {
61+
t.Errorf("Expected to find entity 'Windows 10' with ModelFromFS, got = %v", ents[0].Text)
62+
}
63+
64+
if ents[0].Label != "PRODUCT" {
65+
t.Errorf("Expected to tab entity with PRODUCT, got = %v", ents[0].Label)
66+
}
67+
}

utilities.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package prose
33
import (
44
"bytes"
55
"encoding/gob"
6-
"os"
6+
"io/fs"
77
"path"
88
"strconv"
99
"strings"
@@ -46,10 +46,8 @@ func getAsset(folder, name string) *gob.Decoder {
4646
return gob.NewDecoder(bytes.NewReader(b))
4747
}
4848

49-
func getDiskAsset(path string) *gob.Decoder {
50-
f, err := os.Open(path)
51-
checkError(err)
52-
return gob.NewDecoder(f)
49+
func getDiskAsset(file fs.File) *gob.Decoder {
50+
return gob.NewDecoder(file)
5351
}
5452

5553
func hasAnyPrefix(s string, prefixes []string) bool {

0 commit comments

Comments
 (0)