@@ -800,7 +800,7 @@ def from_pretrained( # pylint: disable=too-many-locals
800
800
"""from_pretrained"""
801
801
state_dict = kwargs .pop ("state_dict" , None )
802
802
cache_dir = kwargs .pop ("cache_dir" , None )
803
- from_pt = kwargs .pop ("from_pt" , True )
803
+ _ = kwargs .pop ("from_pt" , True )
804
804
force_download = kwargs .pop ("force_download" , False )
805
805
resume_download = kwargs .pop ("resume_download" , False )
806
806
proxies = kwargs .pop ("proxies" , None )
@@ -839,7 +839,7 @@ def from_pretrained( # pylint: disable=too-many-locals
839
839
pretrained_model_name_or_path = str (pretrained_model_name_or_path )
840
840
is_local = os .path .isdir (pretrained_model_name_or_path )
841
841
if is_local :
842
- if from_pt and os .path .isfile (
842
+ if os .path .isfile (
843
843
os .path .join (pretrained_model_name_or_path , subfolder , PT_WEIGHTS_NAME )
844
844
):
845
845
# Load from a PyTorch checkpoint
@@ -858,7 +858,7 @@ def from_pretrained( # pylint: disable=too-many-locals
858
858
archive_file = os .path .join (
859
859
pretrained_model_name_or_path , subfolder , _add_variant (SAFE_WEIGHTS_NAME , variant )
860
860
)
861
- elif from_pt and os .path .isfile (
861
+ elif os .path .isfile (
862
862
os .path .join (pretrained_model_name_or_path , subfolder , _add_variant (PT_WEIGHTS_INDEX_NAME , variant ))
863
863
):
864
864
# Load from a sharded PyTorch checkpoint
@@ -901,11 +901,12 @@ def from_pretrained( # pylint: disable=too-many-locals
901
901
elif is_remote_url (pretrained_model_name_or_path ):
902
902
filename = pretrained_model_name_or_path
903
903
resolved_archive_file = download_url (pretrained_model_name_or_path )
904
- elif from_pt :
904
+ else :
905
905
if use_safetensors is not False :
906
906
filename = _add_variant (SAFE_WEIGHTS_NAME , variant )
907
907
else :
908
- filename = _add_variant (PT_WEIGHTS_NAME , variant )
908
+ filename = _add_variant (WEIGHTS_NAME , variant )
909
+
909
910
try :
910
911
# Load from URL or cache if already cached
911
912
cached_file_kwargs = {
@@ -935,68 +936,30 @@ def from_pretrained( # pylint: disable=too-many-locals
935
936
if resolved_archive_file is not None :
936
937
is_sharded = True
937
938
use_safetensors = True
938
- else :
939
- # This repo has no safetensors file of any kind, we switch to PyTorch.
940
- filename = _add_variant (PT_WEIGHTS_NAME , variant )
941
- resolved_archive_file = cached_file (
942
- pretrained_model_name_or_path , filename , ** cached_file_kwargs
943
- )
944
939
945
940
if resolved_archive_file is None :
946
- filename = _add_variant (PT_WEIGHTS_NAME , variant )
941
+ filename = _add_variant (WEIGHTS_NAME , variant )
947
942
resolved_archive_file = cached_file (pretrained_model_name_or_path , filename , ** cached_file_kwargs )
948
943
949
- if resolved_archive_file is None and filename == _add_variant (PT_WEIGHTS_NAME , variant ):
944
+ if resolved_archive_file is None and filename == _add_variant (WEIGHTS_NAME , variant ):
950
945
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
951
946
resolved_archive_file = cached_file (
952
947
pretrained_model_name_or_path ,
953
- _add_variant (PT_WEIGHTS_INDEX_NAME , variant ),
948
+ _add_variant (WEIGHTS_INDEX_NAME , variant ),
954
949
** cached_file_kwargs ,
955
950
)
956
951
if resolved_archive_file is not None :
957
952
is_sharded = True
958
953
959
954
if resolved_archive_file is None :
960
- raise EnvironmentError (
961
- f"{ pretrained_model_name_or_path } does not appear to have a file named"
962
- f" { _add_variant (SAFE_WEIGHTS_NAME , variant )} , { _add_variant (PT_WEIGHTS_NAME , variant )} "
963
- )
964
- except EnvironmentError :
965
- # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
966
- # to the original exception.
967
- raise
968
- except Exception as exc :
969
- # For any other exception, we throw a generic error.
970
- raise EnvironmentError (
971
- f"Can't load the model for '{ pretrained_model_name_or_path } '. If you were trying to load it"
972
- ", make sure you don't have a local directory with the"
973
- f" same name. Otherwise, make sure '{ pretrained_model_name_or_path } ' is the correct path to a"
974
- f" directory containing a file named { _add_variant (SAFE_WEIGHTS_NAME , variant )} ,"
975
- f" { _add_variant (PT_WEIGHTS_NAME , variant )} ."
976
- ) from exc
977
- else :
978
- # set correct filename
979
- filename = _add_variant (WEIGHTS_NAME , variant )
980
- try :
981
- # Load from URL or cache if already cached
982
- cached_file_kwargs = {
983
- "cache_dir" : cache_dir ,
984
- "force_download" : force_download ,
985
- "proxies" : proxies ,
986
- "resume_download" : resume_download ,
987
- "local_files_only" : local_files_only ,
988
- "subfolder" : subfolder ,
989
- "_raise_exceptions_for_missing_entries" : False ,
990
- 'token' : token
991
- }
992
-
993
- resolved_archive_file = cached_file (pretrained_model_name_or_path , filename , ** cached_file_kwargs )
955
+ filename = _add_variant (PT_WEIGHTS_NAME , variant )
956
+ resolved_archive_file = cached_file (pretrained_model_name_or_path , filename , ** cached_file_kwargs )
994
957
995
- if resolved_archive_file is None and filename == _add_variant (WEIGHTS_NAME , variant ):
958
+ if resolved_archive_file is None and filename == _add_variant (PT_WEIGHTS_NAME , variant ):
996
959
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
997
960
resolved_archive_file = cached_file (
998
961
pretrained_model_name_or_path ,
999
- _add_variant (WEIGHTS_INDEX_NAME , variant ),
962
+ _add_variant (PT_WEIGHTS_INDEX_NAME , variant ),
1000
963
** cached_file_kwargs ,
1001
964
)
1002
965
if resolved_archive_file is not None :
@@ -1005,7 +968,7 @@ def from_pretrained( # pylint: disable=too-many-locals
1005
968
if resolved_archive_file is None :
1006
969
raise EnvironmentError (
1007
970
f"{ pretrained_model_name_or_path } does not appear to have a file named"
1008
- f" { _add_variant (WEIGHTS_NAME , variant )} . "
971
+ f" { _add_variant (SAFE_WEIGHTS_NAME , variant )} , { _add_variant ( PT_WEIGHTS_NAME , variant ) } "
1009
972
)
1010
973
except EnvironmentError :
1011
974
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
@@ -1017,7 +980,8 @@ def from_pretrained( # pylint: disable=too-many-locals
1017
980
f"Can't load the model for '{ pretrained_model_name_or_path } '. If you were trying to load it"
1018
981
", make sure you don't have a local directory with the"
1019
982
f" same name. Otherwise, make sure '{ pretrained_model_name_or_path } ' is the correct path to a"
1020
- f" directory containing a file named { _add_variant (WEIGHTS_NAME , variant )} ."
983
+ f" directory containing a file named { _add_variant (WEIGHTS_NAME , variant )} , { _add_variant (SAFE_WEIGHTS_NAME , variant )} ,"
984
+ f" { _add_variant (PT_WEIGHTS_NAME , variant )} ."
1021
985
) from exc
1022
986
1023
987
if is_local :
@@ -1091,8 +1055,8 @@ def empty_initializer(init, shape=None, dtype=mindspore.float32):
1091
1055
# These are all the pointers of shared tensors.
1092
1056
tied_params = [names for _ , names in ptrs .items () if len (names ) > 1 ]
1093
1057
1094
- def load_ckpt (resolved_archive_file , from_pt = False ):
1095
- if from_pt and 'ckpt' not in resolved_archive_file :
1058
+ def load_ckpt (resolved_archive_file ):
1059
+ if 'ckpt' not in resolved_archive_file :
1096
1060
if use_safetensors :
1097
1061
from safetensors .numpy import load_file
1098
1062
origin_state_dict = load_file (resolved_archive_file )
@@ -1214,14 +1178,14 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
1214
1178
if is_sharded :
1215
1179
all_keys_unexpected = []
1216
1180
for name in tqdm (converted_filenames , desc = "Loading checkpoint shards" ):
1217
- state_dict = load_ckpt (name , from_pt )
1181
+ state_dict = load_ckpt (name )
1218
1182
keys_unexpected , keys_missing = load_param_into_net (model , state_dict , cls .base_model_prefix )
1219
1183
all_keys_unexpected .extend (keys_unexpected )
1220
1184
del state_dict
1221
1185
gc .collect ()
1222
1186
loaded_keys = sharded_metadata ["all_checkpoint_keys" ]
1223
1187
else :
1224
- state_dict = load_ckpt (resolved_archive_file , from_pt )
1188
+ state_dict = load_ckpt (resolved_archive_file )
1225
1189
loaded_keys = list (state_dict .keys ())
1226
1190
all_keys_unexpected , keys_missing = load_param_into_net (model , state_dict , cls .base_model_prefix )
1227
1191
else :
@@ -1266,7 +1230,6 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
1266
1230
# Set model in evaluation mode to deactivate DropOut modules by default
1267
1231
model .set_train (False )
1268
1232
1269
- kwargs ['from_pt' ] = from_pt
1270
1233
# If it is a model with generation capabilities, attempt to load the generation config
1271
1234
if model .can_generate () and pretrained_model_name_or_path is not None :
1272
1235
try :
0 commit comments