Commit

clean all pylint warnings
Trace2333 committed Oct 30, 2024
1 parent 6619399 commit 8ea1944
Showing 6 changed files with 53 additions and 117 deletions.
6 changes: 4 additions & 2 deletions mindnlp/trl/__init__.py
@@ -17,5 +17,7 @@
from .trainer import (
DPOTrainer,
DPOConfig,
FDivergenceType
)
FDivergenceType,
_build_tokenized_answer,
_truncate_tokens
)
11 changes: 0 additions & 11 deletions mindnlp/trl/import_utils.py
@@ -1,11 +0,0 @@



def is_peft_available():
return True


def is_wandb_available():
return False


3 changes: 2 additions & 1 deletion mindnlp/trl/models/__init__.py
@@ -1,4 +1,5 @@
"""trl model __init__"""
from .modeling_base import (
PreTrainedModelWrapper,
create_reference_model
)
)
60 changes: 15 additions & 45 deletions mindnlp/trl/models/modeling_base.py
@@ -11,18 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
"""trl model main class."""
import logging
import os
from copy import deepcopy
from typing import Optional

import mindspore
import mindspore.nn as nn
# from accelerate import PartialState
from mindspore import nn

# from safetensors.torch import load_file as safe_load_file
from ...transformers import GenerationMixin, PreTrainedModel
from ...transformers import PreTrainedModel

from ...peft import (
PeftConfig,
@@ -40,10 +38,6 @@
"model.layers.{layer}",
]

def is_peft_available():
# use mindnlp internal peft module.
return True


class PreTrainedModelWrapper(nn.Cell):
r"""
@@ -158,28 +152,17 @@ class and the arguments that are specific to trl models. The kwargs
)
pretrained_kwargs["device_map"] = {"": current_device}

if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig):
if peft_config is not None and not isinstance(peft_config, PeftConfig):
raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.")

# First, load the pre-trained model using the parent-class
# either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`
if isinstance(pretrained_model_name_or_path, str):
if is_peft_available():
# try:
# # If there is a trained peft adapter in the hub, load its config.
# remote_adapter_config = hf_hub_download(
# pretrained_model_name_or_path,
# "adapter_config.json",
# token=token,
# )
# except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
remote_adapter_config = None
else:
remote_adapter_config = None
remote_adapter_config = None

local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json"))

if (local_adapter_present or remote_adapter_config is not None) and is_peft_available():
if local_adapter_present or remote_adapter_config is not None:
if peft_config is not None:
logging.warning(
"`peft_config` argument ignored since a peft config file was found in "
@@ -236,14 +219,13 @@ class and the arguments that are specific to trl models. The kwargs
f"but is {type(pretrained_model_name_or_path)}"
)

if is_peft_available():
if isinstance(pretrained_model, PeftModel):
is_peft_model = True
# for backward compatibility
if hasattr(pretrained_model, "active_peft_config") and isinstance(
pretrained_model.active_peft_config, PromptLearningConfig
):
raise ValueError("PromptLearningConfig is not supported for PPO training.")
if isinstance(pretrained_model, PeftModel):
is_peft_model = True
# for backward compatibility
if hasattr(pretrained_model, "active_peft_config") and isinstance(
pretrained_model.active_peft_config, PromptLearningConfig
):
raise ValueError("PromptLearningConfig is not supported for PPO training.")

# Add reward modeling adapter if specified
if not is_peft_model and reward_adapter is not None:
@@ -298,8 +280,7 @@ class and the arguments that are specific to trl models. The kwargs
# use_safe = False

loading_func = mindspore.load_checkpoint
load_kwargs = {} if use_safe else {"map_location": "cpu", "weights_only": True}

load_kwargs = {}
if is_resuming_training:
# if is_sharded:
# # download each file and add it to the state_dict
@@ -406,11 +387,6 @@ def _split_kwargs(cls, kwargs):
"""
check_peft_kwargs = False

if is_peft_available():
from peft import prepare_model_for_kbit_training

check_peft_kwargs = True

supported_kwargs = {}
unsupported_kwargs = {}
peft_kwargs = {}
@@ -421,12 +397,6 @@ def _split_kwargs(cls, kwargs):
else:
unsupported_kwargs[key] = value

if check_peft_kwargs:
if key in prepare_model_for_kbit_training.__code__.co_varnames:
peft_kwargs[key] = value
if key in unsupported_kwargs:
unsupported_kwargs.pop(key)

return supported_kwargs, unsupported_kwargs, peft_kwargs

@classmethod
@@ -474,7 +444,7 @@ def add_and_load_reward_modeling_adapter(
local_filename = filename

loading_func = mindspore.load_checkpoint
load_kwargs = {} if safe_loading else {"map_location": "cpu", "weights_only": True}
load_kwargs = {}

adapter_state_dict = loading_func(local_filename, **load_kwargs)

2 changes: 1 addition & 1 deletion mindnlp/trl/trainer/__init__.py
@@ -1,4 +1,4 @@
"""trainer init file."""
# from .base import BaseTrainer
from .dpo_trainer import DPOTrainer
from .dpo_trainer import DPOTrainer, _build_tokenized_answer, _truncate_tokens
from .dpo_config import DPOConfig, FDivergenceType
88 changes: 31 additions & 57 deletions mindnlp/trl/trainer/dpo_trainer.py
@@ -18,32 +18,34 @@
import random
import warnings
from collections import defaultdict

import numpy
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from tqdm import tqdm
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, NewType


import numpy
from tqdm import tqdm

import mindspore
from mindspore import amp, nn, ops, Dataset
from huggingface_hub.utils._deprecation import _deprecate_arguments
from mindspore import amp, nn, ops
from mindspore import Tensor
from mindspore.dataset import GeneratorDataset
from huggingface_hub.utils._deprecation import _deprecate_arguments

from ...transformers import (
AutoModelForCausalLM,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from mindspore import Tensor

from ...engine import (
Trainer,
)

from ...transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from ...engine.callbacks import TrainerCallback
from ...engine.utils import EvalLoopOutput
from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@@ -58,15 +60,7 @@
peft_module_casting_to_bf16,
)


if is_peft_available():
from ...peft import PeftModel, get_peft_model

# if is_wandb_available():
# import wandb

# if is_deepspeed_available():
# import deepspeed
from ...peft import PeftModel, get_peft_model

InputDataClass = NewType("InputDataClass", Any)
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
@@ -177,8 +171,7 @@ def _adjust_prompt_length(
for k, v in p_tokens.items():
p_tokens[k] = v[:min_len]

num_diff_tokens = sum([a != b for a, b in
zip(c_tokens["prompt_input_ids"], r_tokens["prompt_input_ids"])])
num_diff_tokens = sum(a != b for a, b in zip(c_tokens["prompt_input_ids"], r_tokens["prompt_input_ids"]))
num_diff_len = abs(c_len - r_len)
if num_diff_tokens > 1 or num_diff_len > 1:
raise ValueError(
@@ -475,8 +468,8 @@ def __init__(
label_pad_token_id: int = -100,
padding_value: Optional[int] = None,
truncation_mode: str = "keep_end",
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
train_dataset: Optional[GeneratorDataset] = None,
eval_dataset: Optional[Union[GeneratorDataset, Dict[str, GeneratorDataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
@@ -516,17 +509,12 @@ def __init__(
)
else:
model_init_kwargs = args.model_init_kwargs
torch_dtype = model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
ms_dtype = model_init_kwargs.get("ms_dtype")
if ms_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(mindspore, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, mindspore.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string"
"with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
model_init_kwargs["torch_dtype"] = torch_dtype
if isinstance(ms_dtype, str) and ms_dtype != "auto":
ms_dtype = getattr(mindspore, ms_dtype)
model_init_kwargs["ms_dtype"] = ms_dtype

if ref_model_init_kwargs is not None:
warnings.warn(
@@ -544,17 +532,12 @@ def __init__(
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
ms_dtype = ref_model_init_kwargs.get("ms_dtype")
if ms_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(mindspore, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, mindspore.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string"
"with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
ref_model_init_kwargs["torch_dtype"] = torch_dtype
if isinstance(ms_dtype, str) and ms_dtype != "auto":
ms_dtype = getattr(mindspore, ms_dtype)
ref_model_init_kwargs["ms_dtype"] = ms_dtype

if isinstance(model, str):
warnings.warn(
@@ -582,12 +565,8 @@ def __init__(
)
args.force_use_ref_model = force_use_ref_model

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs,"
"please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:

if peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
@@ -645,11 +624,6 @@ def make_inputs_require_grad(module, input, output):
"passed will override the one in the `DPOConfig`."
)
args.generate_during_eval = generate_during_eval
if args.generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)

if is_encoder_decoder is not None:
warnings.warn(
@@ -683,7 +657,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.tokenizer = tokenizer

self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.is_peft_model = isinstance(model, PeftModel)
if model_adapter_name is not None:
warnings.warn(
"You passed `model_adapter_name` to the DPOTrainer, the value you"
@@ -1064,7 +1038,7 @@ def get_train_dataloader(self) -> GeneratorDataset:

def get_eval_dataloader(
self,
eval_dataset: Optional[Dataset] = None
eval_dataset: Optional[GeneratorDataset] = None
) -> GeneratorDataset:
"""
Returns the evaluation [`~torch.utils.data.GeneratorDataset`].
@@ -1235,7 +1209,7 @@ def concatenated_inputs(
batch["prompt_attention_mask"].repeat(2, 1)
)
concatenated_batch["concatenated_decoder_input_ids"] = mindspore.ops.cat(
[batch["chosen_decoder_input_ids"], batch["rejected_decoder_input_ids"]], dim=0
[batch["chosen_decoder_input_ids"], batch["rejected_decoder_input_ids"]], axis=0
)

if is_vision_model:
@@ -1299,7 +1273,7 @@ def dpo_loss(
else:
pi_logratios = policy_chosen_logps - policy_rejected_logps
if self.reference_free:
ref_logratios = mindspore.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios)
ref_logratios = mindspore.tensor([0], dtype=pi_logratios.dtype)
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps

@@ -1320,7 +1294,7 @@
# in the range of 0.1 to 0.5.We ignore the reference model as beta -> 0.
# The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.

if self.loss_type == "sigmoid":
losses = (
- ops.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
@@ -1771,7 +1745,7 @@ def prediction_step(
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = mindspore.ops.stack(logits).mean(axis=1).to(self.accelerator.device)
labels = mindspore.ops.zeros(logits.shape[0], device=self.accelerator.device)
labels = mindspore.ops.zeros(logits.shape[0])

return (loss, logits, labels)

Expand Down
