Skip to content

Commit

Permalink
Merge pull request #39 from CodeWithKyrian/37-cli-download-model-reso…
Browse files Browse the repository at this point in the history
…lution

Fix bug with Download CLI - use named parameters for model construct
  • Loading branch information
CodeWithKyrian authored May 24, 2024
2 parents 0618360 + e58f3e9 commit ebfb298
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions src/Models/Pretrained/PretrainedModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static function fromPretrained(
string $revision = 'main',
?string $modelFilename = null,
ModelArchitecture $modelArchitecture = ModelArchitecture::EncoderOnly,
?callable $onProgress = null
?callable $onProgress = null
): self
{
if (is_array($config)) {
Expand Down Expand Up @@ -115,7 +115,12 @@ public static function fromPretrained(

$generatorConfig = new GenerationConfig($generatorConfigArr);

return new static($config, $session, $modelArchitecture, $generatorConfig);
return new static(
config: $config,
session: $session,
modelArchitecture: $modelArchitecture,
generationConfig: $generatorConfig
);
}

case ModelArchitecture::Seq2SeqLM:
Expand Down Expand Up @@ -148,8 +153,13 @@ public static function fromPretrained(

$generatorConfig = new GenerationConfig($generatorConfigArr);


return new static($config, $encoderSession, $decoderSession, $modelArchitecture, $generatorConfig);
return new static(
config: $config,
session: $encoderSession,
modelArchitecture: $modelArchitecture,
generationConfig: $generatorConfig,
decoderMergedSession: $decoderSession
);
}

case ModelArchitecture::MaskGeneration:
Expand All @@ -170,7 +180,12 @@ public static function fromPretrained(
onProgress: $onProgress
);

return new static($config, $visionEncoder, $promptMaskEncoder, $modelArchitecture);
return new static(
config: $config,
session: $visionEncoder,
promptMaskEncoderSession: $promptMaskEncoder,
modelArchitecture: $modelArchitecture
);
}

case ModelArchitecture::EncoderDecoder:
Expand All @@ -191,7 +206,12 @@ public static function fromPretrained(
onProgress: $onProgress
);

return new static($config, $encoderSession, $decoderSession, $modelArchitecture);
return new static(
config: $config,
session: $encoderSession,
decoderMergedSession: $decoderSession,
modelArchitecture: $modelArchitecture
);
}

default:
Expand All @@ -210,7 +230,11 @@ public static function fromPretrained(
);


return new static($config, $session, $modelArchitecture);
return new static(
config: $config,
session: $session,
modelArchitecture: $modelArchitecture
);
}
}
}
Expand All @@ -232,14 +256,14 @@ public static function fromPretrained(
*/

public static function constructSession(
string $modelNameOrPath,
string $fileName,
?string $cacheDir = null,
string $revision = 'main',
string $subFolder = 'onnx',
bool $fatal = true,
string $modelNameOrPath,
string $fileName,
?string $cacheDir = null,
string $revision = 'main',
string $subFolder = 'onnx',
bool $fatal = true,
?callable $onProgress = null,
...$sessionOptions
...$sessionOptions
): ?InferenceSession
{
$modelFileName = "$fileName.onnx";
Expand Down

0 comments on commit ebfb298

Please sign in to comment.