Skip to content

Commit e58f3e9

Browse files
Used named arguments instead in model class construction
1 parent 0618360 commit e58f3e9

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

src/Models/Pretrained/PretrainedModel.php

+38-14
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public static function fromPretrained(
8484
string $revision = 'main',
8585
?string $modelFilename = null,
8686
ModelArchitecture $modelArchitecture = ModelArchitecture::EncoderOnly,
87-
?callable $onProgress = null
87+
?callable $onProgress = null
8888
): self
8989
{
9090
if (is_array($config)) {
@@ -115,7 +115,12 @@ public static function fromPretrained(
115115

116116
$generatorConfig = new GenerationConfig($generatorConfigArr);
117117

118-
return new static($config, $session, $modelArchitecture, $generatorConfig);
118+
return new static(
119+
config: $config,
120+
session: $session,
121+
modelArchitecture: $modelArchitecture,
122+
generationConfig: $generatorConfig
123+
);
119124
}
120125

121126
case ModelArchitecture::Seq2SeqLM:
@@ -148,8 +153,13 @@ public static function fromPretrained(
148153

149154
$generatorConfig = new GenerationConfig($generatorConfigArr);
150155

151-
152-
return new static($config, $encoderSession, $decoderSession, $modelArchitecture, $generatorConfig);
156+
return new static(
157+
config: $config,
158+
session: $encoderSession,
159+
modelArchitecture: $modelArchitecture,
160+
generationConfig: $generatorConfig,
161+
decoderMergedSession: $decoderSession
162+
);
153163
}
154164

155165
case ModelArchitecture::MaskGeneration:
@@ -170,7 +180,12 @@ public static function fromPretrained(
170180
onProgress: $onProgress
171181
);
172182

173-
return new static($config, $visionEncoder, $promptMaskEncoder, $modelArchitecture);
183+
return new static(
184+
config: $config,
185+
session: $visionEncoder,
186+
promptMaskEncoderSession: $promptMaskEncoder,
187+
modelArchitecture: $modelArchitecture
188+
);
174189
}
175190

176191
case ModelArchitecture::EncoderDecoder:
@@ -191,7 +206,12 @@ public static function fromPretrained(
191206
onProgress: $onProgress
192207
);
193208

194-
return new static($config, $encoderSession, $decoderSession, $modelArchitecture);
209+
return new static(
210+
config: $config,
211+
session: $encoderSession,
212+
decoderMergedSession: $decoderSession,
213+
modelArchitecture: $modelArchitecture
214+
);
195215
}
196216

197217
default:
@@ -210,7 +230,11 @@ public static function fromPretrained(
210230
);
211231

212232

213-
return new static($config, $session, $modelArchitecture);
233+
return new static(
234+
config: $config,
235+
session: $session,
236+
modelArchitecture: $modelArchitecture
237+
);
214238
}
215239
}
216240
}
@@ -232,14 +256,14 @@ public static function fromPretrained(
232256
*/
233257

234258
public static function constructSession(
235-
string $modelNameOrPath,
236-
string $fileName,
237-
?string $cacheDir = null,
238-
string $revision = 'main',
239-
string $subFolder = 'onnx',
240-
bool $fatal = true,
259+
string $modelNameOrPath,
260+
string $fileName,
261+
?string $cacheDir = null,
262+
string $revision = 'main',
263+
string $subFolder = 'onnx',
264+
bool $fatal = true,
241265
?callable $onProgress = null,
242-
...$sessionOptions
266+
...$sessionOptions
243267
): ?InferenceSession
244268
{
245269
$modelFileName = "$fileName.onnx";

0 commit comments

Comments
 (0)