@@ -84,7 +84,7 @@ public static function fromPretrained(
84
84
string $ revision = 'main ' ,
85
85
?string $ modelFilename = null ,
86
86
ModelArchitecture $ modelArchitecture = ModelArchitecture::EncoderOnly,
87
- ?callable $ onProgress = null
87
+ ?callable $ onProgress = null
88
88
): self
89
89
{
90
90
if (is_array ($ config )) {
@@ -115,7 +115,12 @@ public static function fromPretrained(
115
115
116
116
$ generatorConfig = new GenerationConfig ($ generatorConfigArr );
117
117
118
- return new static ($ config , $ session , $ modelArchitecture , $ generatorConfig );
118
+ return new static (
119
+ config: $ config ,
120
+ session: $ session ,
121
+ modelArchitecture: $ modelArchitecture ,
122
+ generationConfig: $ generatorConfig
123
+ );
119
124
}
120
125
121
126
case ModelArchitecture::Seq2SeqLM:
@@ -148,8 +153,13 @@ public static function fromPretrained(
148
153
149
154
$ generatorConfig = new GenerationConfig ($ generatorConfigArr );
150
155
151
-
152
- return new static ($ config , $ encoderSession , $ decoderSession , $ modelArchitecture , $ generatorConfig );
156
+ return new static (
157
+ config: $ config ,
158
+ session: $ encoderSession ,
159
+ modelArchitecture: $ modelArchitecture ,
160
+ generationConfig: $ generatorConfig ,
161
+ decoderMergedSession: $ decoderSession
162
+ );
153
163
}
154
164
155
165
case ModelArchitecture::MaskGeneration:
@@ -170,7 +180,12 @@ public static function fromPretrained(
170
180
onProgress: $ onProgress
171
181
);
172
182
173
- return new static ($ config , $ visionEncoder , $ promptMaskEncoder , $ modelArchitecture );
183
+ return new static (
184
+ config: $ config ,
185
+ session: $ visionEncoder ,
186
+ promptMaskEncoderSession: $ promptMaskEncoder ,
187
+ modelArchitecture: $ modelArchitecture
188
+ );
174
189
}
175
190
176
191
case ModelArchitecture::EncoderDecoder:
@@ -191,7 +206,12 @@ public static function fromPretrained(
191
206
onProgress: $ onProgress
192
207
);
193
208
194
- return new static ($ config , $ encoderSession , $ decoderSession , $ modelArchitecture );
209
+ return new static (
210
+ config: $ config ,
211
+ session: $ encoderSession ,
212
+ decoderMergedSession: $ decoderSession ,
213
+ modelArchitecture: $ modelArchitecture
214
+ );
195
215
}
196
216
197
217
default :
@@ -210,7 +230,11 @@ public static function fromPretrained(
210
230
);
211
231
212
232
213
- return new static ($ config , $ session , $ modelArchitecture );
233
+ return new static (
234
+ config: $ config ,
235
+ session: $ session ,
236
+ modelArchitecture: $ modelArchitecture
237
+ );
214
238
}
215
239
}
216
240
}
@@ -232,14 +256,14 @@ public static function fromPretrained(
232
256
*/
233
257
234
258
public static function constructSession (
235
- string $ modelNameOrPath ,
236
- string $ fileName ,
237
- ?string $ cacheDir = null ,
238
- string $ revision = 'main ' ,
239
- string $ subFolder = 'onnx ' ,
240
- bool $ fatal = true ,
259
+ string $ modelNameOrPath ,
260
+ string $ fileName ,
261
+ ?string $ cacheDir = null ,
262
+ string $ revision = 'main ' ,
263
+ string $ subFolder = 'onnx ' ,
264
+ bool $ fatal = true ,
241
265
?callable $ onProgress = null ,
242
- ...$ sessionOptions
266
+ ...$ sessionOptions
243
267
): ?InferenceSession
244
268
{
245
269
$ modelFileName = "$ fileName.onnx " ;
0 commit comments