CharacterIterator.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.*;
/** A simple DataSetIterator for use in the GravesLSTMCharModellingExample.
 * Given a text file and a few options, generate feature vectors and labels for training,
 * where we want to predict the next character in the sequence.<br>
 * This is done by splitting the text into sequences of exampleLength characters, starting at
 * offsets 0, exampleLength, 2*exampleLength, etc., and presenting those sequences in a shuffled order.
 * Each character is then converted to an index and encoded as a one-hot vector:
 * the character 'a' becomes [1,0,0,0,...], 'b' becomes [0,1,0,0,...], etc.
 *
 * Feature vectors and labels are both one-hot vectors of the same length.
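 *
 * <p>Example usage (a minimal sketch; the corpus path, encoding and hyper-parameter values here
 * are illustrative assumptions, not values prescribed by this class):
 * <pre>{@code
 * CharacterIterator iter = new CharacterIterator("path/to/corpus.txt",
 *         Charset.forName("UTF-8"), 32, 1000,
 *         CharacterIterator.getMinimalCharacterSet(), new Random(12345));
 * while (iter.hasNext()) {
 *     DataSet ds = iter.next();   //features and labels each have shape [32, validCharacters.length, 1000]
 * }
 * }</pre>
 *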
* @author Alex Black
*/
public class CharacterIterator implements DataSetIterator {
//Valid characters
private char[] validCharacters;
    //Maps each character to an index in the input/output
private Map<Character,Integer> charToIdxMap;
    //All characters of the input file (after filtering to only those that are valid)
private char[] fileCharacters;
//Length of each example/minibatch (number of characters)
private int exampleLength;
//Size of each minibatch (number of examples)
private int miniBatchSize;
private Random rng;
//Offsets for the start of each example
private LinkedList<Integer> exampleStartOffsets = new LinkedList<>();
/**
* @param textFilePath Path to text file to use for generating samples
* @param textFileEncoding Encoding of the text file. Can try Charset.defaultCharset()
* @param miniBatchSize Number of examples per mini-batch
* @param exampleLength Number of characters in each input/output vector
* @param validCharacters Character array of valid characters. Characters not present in this array will be removed
* @param rng Random number generator, for repeatability if required
* @throws IOException If text file cannot be loaded
*/
public CharacterIterator(String textFilePath, Charset textFileEncoding, int miniBatchSize, int exampleLength,
char[] validCharacters, Random rng) throws IOException {
if( !new File(textFilePath).exists()) throw new IOException("Could not access file (does not exist): " + textFilePath);
if( miniBatchSize <= 0 ) throw new IllegalArgumentException("Invalid miniBatchSize (must be >0)");
this.validCharacters = validCharacters;
this.exampleLength = exampleLength;
this.miniBatchSize = miniBatchSize;
this.rng = rng;
        //Store valid characters in a map for later use in vectorization
charToIdxMap = new HashMap<>();
for( int i=0; i<validCharacters.length; i++ ) charToIdxMap.put(validCharacters[i], i);
//Load file and convert contents to a char[]
boolean newLineValid = charToIdxMap.containsKey('\n');
List<String> lines = Files.readAllLines(new File(textFilePath).toPath(),textFileEncoding);
        int maxSize = lines.size(); //start from lines.size() to account for the newline character at the end of each line
for( String s : lines ) maxSize += s.length();
char[] characters = new char[maxSize];
int currIdx = 0;
for( String s : lines ){
char[] thisLine = s.toCharArray();
            for( char c : thisLine ){
                if( !charToIdxMap.containsKey(c) ) continue;
                characters[currIdx++] = c;
}
if(newLineValid) characters[currIdx++] = '\n';
}
if( currIdx == characters.length ){
fileCharacters = characters;
} else {
fileCharacters = Arrays.copyOfRange(characters, 0, currIdx);
}
        if( exampleLength >= fileCharacters.length ) throw new IllegalArgumentException("exampleLength="+exampleLength
                +" must be less than the number of valid characters in the file ("+fileCharacters.length+")");
int nRemoved = maxSize - fileCharacters.length;
System.out.println("Loaded and converted file: " + fileCharacters.length + " valid characters of "
+ maxSize + " total characters (" + nRemoved + " removed)");
initializeOffsets();
}
/** A minimal character set, with a-z, A-Z, 0-9 and common punctuation etc */
public static char[] getMinimalCharacterSet(){
List<Character> validChars = new LinkedList<>();
for(char c='a'; c<='z'; c++) validChars.add(c);
for(char c='A'; c<='Z'; c++) validChars.add(c);
for(char c='0'; c<='9'; c++) validChars.add(c);
char[] temp = {'!', '&', '(', ')', '?', '-', '\'', '"', ',', '.', ':', ';', ' ', '\n', '\t'};
for( char c : temp ) validChars.add(c);
char[] out = new char[validChars.size()];
int i=0;
for( Character c : validChars ) out[i++] = c;
return out;
}
/** As per getMinimalCharacterSet(), but with a few extra characters */
public static char[] getDefaultCharacterSet(){
List<Character> validChars = new LinkedList<>();
for(char c : getMinimalCharacterSet() ) validChars.add(c);
char[] additionalChars = {'@', '#', '$', '%', '^', '*', '{', '}', '[', ']', '/', '+', '_',
'\\', '|', '<', '>'};
for( char c : additionalChars ) validChars.add(c);
char[] out = new char[validChars.size()];
int i=0;
for( Character c : validChars ) out[i++] = c;
return out;
}
public char convertIndexToCharacter( int idx ){
return validCharacters[idx];
}
public int convertCharacterToIndex( char c ){
return charToIdxMap.get(c);
}
public char getRandomCharacter(){
return validCharacters[(int) (rng.nextDouble()*validCharacters.length)];
}
public boolean hasNext() {
return exampleStartOffsets.size() > 0;
}
public DataSet next() {
return next(miniBatchSize);
}
public DataSet next(int num) {
if( exampleStartOffsets.size() == 0 ) throw new NoSuchElementException();
int currMinibatchSize = Math.min(num, exampleStartOffsets.size());
//Allocate space:
//Note the order here:
// dimension 0 = number of examples in minibatch
// dimension 1 = size of each vector (i.e., number of characters)
// dimension 2 = length of each time series/example
//Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data section "Alternative: Implementing a custom DataSetIterator"
INDArray input = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f');
INDArray labels = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f');
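        //Entry [i, charIdx, t] == 1.0 below means "example i has the character with index charIdx
        //at time step t"; all other entries remain 0.0, i.e., one-hot along dimension 1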
for( int i=0; i<currMinibatchSize; i++ ){
int startIdx = exampleStartOffsets.removeFirst();
int endIdx = startIdx + exampleLength;
int currCharIdx = charToIdxMap.get(fileCharacters[startIdx]); //Current input
int c=0;
            for( int j=startIdx+1; j<=endIdx; j++, c++ ){ //<= so all exampleLength time steps are filled; initializeOffsets() leaves enough slack that endIdx stays in range
int nextCharIdx = charToIdxMap.get(fileCharacters[j]); //Next character to predict
input.putScalar(new int[]{i,currCharIdx,c}, 1.0);
labels.putScalar(new int[]{i,nextCharIdx,c}, 1.0);
currCharIdx = nextCharIdx;
}
}
return new DataSet(input,labels);
}
public int totalExamples() {
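        //Must equal the number of offsets created in initializeOffsets(), so that cursor() counts consumed examples correctly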
        return (fileCharacters.length-1) / exampleLength - 2;
}
public int inputColumns() {
return validCharacters.length;
}
public int totalOutcomes() {
return validCharacters.length;
}
public void reset() {
exampleStartOffsets.clear();
initializeOffsets();
}
private void initializeOffsets() {
//This defines the order in which parts of the file are fetched
        int nExamplesPerEpoch = (fileCharacters.length - 1) / exampleLength - 2; //-2: leaves slack for the final example's end index, and skips any trailing partial example
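        //e.g. (illustrative numbers): 10,000 valid characters and exampleLength=1000 gives
        //(10000-1)/1000 - 2 = 7 examples per epoch, starting at offsets 0, 1000, ..., 6000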
        for (int i = 0; i < nExamplesPerEpoch; i++) {
exampleStartOffsets.add(i * exampleLength);
}
Collections.shuffle(exampleStartOffsets, rng);
}
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
public int batch() {
return miniBatchSize;
}
public int cursor() {
return totalExamples() - exampleStartOffsets.size();
}
public int numExamples() {
return totalExamples();
}
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public DataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public List<String> getLabels() {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}