@AgentDS commented Oct 26, 2025

Add a non_blocking argument to ListWrapper.to to support transformers>=4.48.0.

For transformers<4.48.0, transformers.BatchEncoding.to in src/transformers/tokenization_utils_base.py is defined as:

def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
    """
    Send all values to device by calling `v.to(device)` (PyTorch only).

    Args:
        device (`str` or `torch.device`): The device to put the tensors on.

    Returns:
        [`BatchEncoding`]: The same instance after modification.
    """
    requires_backends(self, ["torch"])
    import torch

    # This check catches things like APEX blindly calling "to" on all inputs to a module
    # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
    # into a HalfTensor
    if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
        self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
    else:
        logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
    return self

For transformers>=4.48.0, the same method gained a keyword-only non_blocking parameter:

def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
    """
    Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).

    Args:
        device (`str` or `torch.device`): The device to put the tensors on.
        non_blocking (`bool`): Whether to perform the copy asynchronously.

    Returns:
        [`BatchEncoding`]: The same instance after modification.
    """
    requires_backends(self, ["torch"])
    import torch

    # This check catches things like APEX blindly calling "to" on all inputs to a module
    # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
    # into a HalfTensor
    if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
        self.data = {
            k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
            for k, v in self.data.items()
        }
    else:
        logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
    return self
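
A wrapper that mimics BatchEncoding but whose to method only accepts device will now fail as soon as a caller passes the new keyword. A minimal reproduction with a hypothetical stand-in class (not the project's actual ListWrapper):

class OldStyleWrapper:
    # Hypothetical stand-in whose `to` mirrors the pre-4.48.0 signature.
    def to(self, device):
        return self

OldStyleWrapper().to("cuda", non_blocking=True)
# raises TypeError: to() got an unexpected keyword argument 'non_blocking'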

Thus ListWrapper.to needs to accept a non_blocking argument (and forward it when moving its tensors) to stay compatible with transformers>=4.48.0.
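
A minimal sketch of the corresponding change, assuming ListWrapper keeps its values in a self.data list (the class body here is illustrative; only the updated signature and the forwarding of non_blocking reflect the change being proposed):

import torch
from typing import Union


class ListWrapper:
    # Illustrative sketch: the real class has more functionality; what matters
    # here is the keyword-only non_blocking parameter with a False default.
    def __init__(self, data: list):
        self.data = data

    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "ListWrapper":
        # Mirror transformers>=4.48.0: accept non_blocking and forward it to
        # each tensor's .to() call; older callers that omit it keep working
        # because the parameter defaults to False.
        self.data = [
            v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
            for v in self.data
        ]
        return self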

@AgentDS changed the title from "add non_blocking args to support transformers>=4.48.0" to "add non_blocking=False args to support transformers>=4.48.0" on Oct 26, 2025