|
1 | 1 | package prose |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "io" |
| 5 | + "io/fs" |
4 | 6 | "os" |
5 | 7 | "path/filepath" |
6 | 8 | ) |
@@ -60,12 +62,45 @@ func ModelFromData(name string, sources ...DataSource) *Model { |
60 | 62 |
|
61 | 63 | // ModelFromDisk loads a Model from the user-provided location. |
62 | 64 | func ModelFromDisk(path string) *Model { |
63 | | - name, classifier := loadClassifier(path) |
| 65 | + filesys := os.DirFS(path) |
| 66 | + return &Model{ |
| 67 | + Name: filepath.Base(path), |
| 68 | + |
| 69 | + extracter: loadClassifier(filesys), |
| 70 | + tagger: newPerceptronTagger(), |
| 71 | + } |
| 72 | +} |
| 73 | + |
| 74 | +// ModelFromFS loads a model from the |
| 75 | +func ModelFromFS(name string, filesys fs.FS) *Model { |
| 76 | + // Locate a folder matching name within filesys |
| 77 | + var modelFS fs.FS |
| 78 | + err := fs.WalkDir(filesys, ".", func(path string, d fs.DirEntry, err error) error { |
| 79 | + if err != nil { |
| 80 | + return err |
| 81 | + } |
| 82 | + |
| 83 | + // Model located. Exit tree traversal |
| 84 | + if d.Name() == name { |
| 85 | + modelFS, err = fs.Sub(filesys, path) |
| 86 | + if err != nil { |
| 87 | + return err |
| 88 | + } |
| 89 | + return io.EOF |
| 90 | + } |
| 91 | + |
| 92 | + return nil |
| 93 | + }) |
| 94 | + if err != io.EOF { |
| 95 | + checkError(err) |
| 96 | + } |
| 97 | + |
64 | 98 | return &Model{ |
65 | 99 | Name: name, |
66 | 100 |
|
67 | | - extracter: classifier, |
68 | | - tagger: newPerceptronTagger()} |
| 101 | + extracter: loadClassifier(modelFS), |
| 102 | + tagger: newPerceptronTagger(), |
| 103 | + } |
69 | 104 | } |
70 | 105 |
|
71 | 106 | // Write saves a Model to the user-provided location. |
@@ -96,24 +131,28 @@ func loadTagger(path string) *perceptronTagger { |
96 | 131 | return newTrainedPerceptronTagger(model) |
97 | 132 | }*/ |
98 | 133 |
|
99 | | -func loadClassifier(path string) (string, *entityExtracter) { |
| 134 | +func loadClassifier(filesys fs.FS) *entityExtracter { |
100 | 135 | var mapping map[string]int |
101 | 136 | var weights []float64 |
102 | 137 | var labels []string |
103 | 138 |
|
104 | | - loc := filepath.Join(path, "Maxent") |
105 | | - dec := getDiskAsset(filepath.Join(loc, "mapping.gob")) |
106 | | - checkError(dec.Decode(&mapping)) |
| 139 | + maxent, err := fs.Sub(filesys, "Maxent") |
| 140 | + checkError(err) |
| 141 | + |
| 142 | + file, err := maxent.Open("mapping.gob") |
| 143 | + checkError(err) |
| 144 | + checkError(getDiskAsset(file).Decode(&mapping)) |
107 | 145 |
|
108 | | - dec = getDiskAsset(filepath.Join(loc, "weights.gob")) |
109 | | - checkError(dec.Decode(&weights)) |
| 146 | + file, err = maxent.Open("weights.gob") |
| 147 | + checkError(err) |
| 148 | + checkError(getDiskAsset(file).Decode(&weights)) |
110 | 149 |
|
111 | | - dec = getDiskAsset(filepath.Join(loc, "labels.gob")) |
112 | | - checkError(dec.Decode(&labels)) |
| 150 | + file, err = maxent.Open("labels.gob") |
| 151 | + checkError(err) |
| 152 | + checkError(getDiskAsset(file).Decode(&labels)) |
113 | 153 |
|
114 | 154 | model := newMaxentClassifier(weights, mapping, labels) |
115 | | - name := filepath.Base(path) |
116 | | - return name, newTrainedEntityExtracter(model) |
| 155 | + return newTrainedEntityExtracter(model) |
117 | 156 | } |
118 | 157 |
|
119 | 158 | func defaultModel(tagging, classifying bool) *Model { |
|
0 commit comments