【Hackathon 7th No.40】Add API conversion rules for the Paddle code conversion tool (Group 7) - Part #6920

Open · wants to merge 3 commits into `develop`
## [Composite Replacement] torch.cuda.comm.gather

### [torch.cuda.comm.gather](https://pytorch.org/docs/stable/generated/torch.cuda.comm.gather.html)
```python
torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)
```

Gathers tensors from multiple devices onto a single destination device. Paddle has no equivalent API, so a composite implementation is required.

### Conversion Example

> **Collaborator:** Please submit the corresponding Matcher along with this change and test it.

```python
# PyTorch code
destination = 'cuda:0'
gathered_tensor = torch.cuda.comm.gather(tensors, destination=destination)

# Paddle code
import paddle

def paddle_comm_gather(tensors, dim=0, destination=None, *, out=None):
    # Resolve the destination string to a Paddle place (CPU by default)
    if destination is not None and ('cuda' in destination or 'gpu' in destination):
        place = paddle.CUDAPlace(int(destination.split(':')[-1]))
    else:
        place = paddle.CPUPlace()

    # Move every input tensor onto the destination device, then concatenate
    moved = [paddle.to_tensor(t, place=place) for t in tensors]
    gathered_tensor = paddle.concat(moved, axis=dim)

    if out is not None:
        paddle.assign(gathered_tensor, out)
        return out
    return gathered_tensor

destination = 'gpu:0'
gathered_tensor = paddle_comm_gather(tensors, dim=0, destination=destination)
```
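
A quick way to exercise the `out=` branch of the sketch above (the shapes here are hypothetical, and `'gpu:0'` assumes at least one visible GPU):

```python
import paddle

tensors = [paddle.ones([2, 3]), paddle.zeros([2, 3])]
out = paddle.empty([4, 3])  # pre-allocated result buffer

result = paddle_comm_gather(tensors, dim=0, destination='gpu:0', out=out)
assert result is out  # the gathered values were copied into `out`
```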
## [Composite Replacement] torch.cuda.comm.scatter

### [torch.cuda.comm.scatter](https://pytorch.org/docs/stable/generated/torch.cuda.comm.scatter.html)

```python
torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None)
```

Scatters a tensor across multiple devices. Paddle has no equivalent API, so a composite implementation is required.

### Conversion Example

> **Collaborator:** Please submit the corresponding Matcher along with this change and test it.

```python
# PyTorch code
devices = [torch.device('cuda:0'), torch.device('cuda:1')]
torch.cuda.comm.scatter(inputs, devices=devices)

# Paddle code
import paddle

def paddle_comm_scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
    # `streams` is kept only for signature compatibility; asynchronous
    # copies on explicit CUDA streams are not supported by this replacement.
    if devices is None:
        # As in torch, exactly one of `devices` / `out` must be given;
        # recover the target places from the pre-allocated outputs.
        devices = [str(t.place) for t in out]

    # Split along `dim`: by explicit sizes if given, otherwise into
    # len(devices) equal chunks (the axis size must then be divisible).
    num_or_sections = chunk_sizes if chunk_sizes is not None else len(devices)
    chunks = paddle.split(tensor, num_or_sections=num_or_sections, axis=dim)

    scattered_tensors = []
    for idx, (chunk, device) in enumerate(zip(chunks, devices)):
        device = str(device)
        if 'gpu' in device or 'cuda' in device:
            place = paddle.CUDAPlace(int(device.split(':')[-1].rstrip(')')))
        else:
            place = paddle.CPUPlace()
        tensor_on_device = paddle.to_tensor(chunk, place=place)
        if out is not None:
            paddle.assign(tensor_on_device, out[idx])
        else:
            scattered_tensors.append(tensor_on_device)

    return tuple(out) if out is not None else tuple(scattered_tensors)

devices = ['gpu:0', 'gpu:1']
chunk_sizes = [5, 5]
scattered_tensors = paddle_comm_scatter(tensor, devices=devices, chunk_sizes=chunk_sizes)
```
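
When `chunk_sizes` is omitted, the sketch above splits the input into `len(devices)` equal chunks, so the size along `dim` must be divisible by the number of devices. A minimal usage sketch (hypothetical shape; assumes two visible GPUs):

```python
import paddle

tensor = paddle.arange(20, dtype='float32').reshape([10, 2])
chunks = paddle_comm_scatter(tensor, devices=['gpu:0', 'gpu:1'])
# chunks[0]: rows 0-4 on gpu:0, chunks[1]: rows 5-9 on gpu:1
```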
## [Composite Replacement] torch.cuda.device_of

### [torch.cuda.device_of](https://pytorch.org/docs/stable/generated/torch.cuda.device_of.html#torch.cuda.device_of)
```python
torch.cuda.device_of(obj)
```

Gets the device a tensor resides on. Paddle has no such API, so a composite implementation is required.

> **Collaborator:** https://pytorch.org/docs/stable/generated/torch.cuda.device_of.html#torch.cuda.device_of — this does not seem to be what that API does; please develop the Matcher and test it.

> **Author:** OK.

The device a tensor resides on can be obtained via `tensor.place`.

### Conversion Example
```python
# PyTorch code
device = torch.cuda.device_of(tensor)

# Paddle code
device = tensor.place
```
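
If the converted code needs to branch on where the tensor lives, the string form of `tensor.place` can be parsed; a minimal sketch (the `Place(gpu:0)` / `Place(cpu)` format is an assumption about Paddle's `Place` repr):

```python
import paddle

tensor = paddle.ones([2, 2])
place_str = str(tensor.place)  # e.g. 'Place(cpu)' or 'Place(gpu:0)'
if 'gpu' in place_str:
    device_id = int(place_str.split(':')[-1].rstrip(')'))
    print(f'tensor is on GPU {device_id}')
else:
    print('tensor is on CPU')
```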
## [Composite Replacement] torch.cuda.is_initialized

### [torch.cuda.is_initialized](https://pytorch.org/docs/stable/generated/torch.cuda.is_initialized.html)

```python
torch.cuda.is_initialized()
```

Checks whether CUDA has been initialized. Paddle has no such API, so a composite implementation is required.
Paddle can approximate this by checking whether it was compiled with CUDA support and then attempting to create a tensor on the GPU.

### Conversion Example

```python
# PyTorch code
torch.cuda.is_initialized()

# Paddle code
import paddle

def paddle_cuda_is_initialized():
    if not paddle.is_compiled_with_cuda():
        return False
    try:
        # Creating a tensor on the GPU forces CUDA initialization;
        # failure means CUDA cannot be (or has not been) initialized.
        paddle.to_tensor([1.0], place=paddle.CUDAPlace(0))
        return True
    except Exception:
        return False

paddle_cuda_is_initialized()
```
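
Note that, unlike `torch.cuda.is_initialized()`, the probe above initializes CUDA as a side effect. When only availability (rather than initialization state) matters, a cheaper check avoids touching the CUDA context entirely; a sketch:

```python
import paddle

def paddle_cuda_is_available():
    # True only when Paddle was built with CUDA and at least one GPU
    # is visible; this creates no tensors, so nothing gets initialized.
    return paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0
```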
## [Composite Replacement] torch.get_default_device

### [torch.get_default_device](https://pytorch.org/docs/stable/generated/torch.get_default_device.html)
```python
torch.get_default_device()
```

Gets the default device. Paddle has no such API, so a composite implementation is required.

### Conversion Example
```python
# PyTorch code
device = torch.get_default_device()

# Paddle code
device = paddle.device.get_device()
```
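
One difference to keep in mind: `torch.get_default_device()` returns a `torch.device` object, while `paddle.device.get_device()` returns a plain string such as `'cpu'` or `'gpu:0'`. A minimal sketch for recovering the device index when the converted code needs it:

```python
import paddle

device_str = paddle.device.get_device()  # e.g. 'cpu' or 'gpu:0'
if device_str.startswith('gpu'):
    index = int(device_str.split(':')[-1])  # numeric device index
```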
## [Composite Replacement] torch.set_default_device

### [torch.set_default_device](https://pytorch.org/docs/stable/generated/torch.set_default_device.html#torch.set_default_device)
```python
torch.set_default_device(device)
```

Sets the default device. Paddle has no such API, so a composite implementation is required.

### Conversion Example

```python
# PyTorch code
torch.set_default_device(device)

# Paddle code
paddle.device.set_device(device)
```
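
The device strings also differ between the two frameworks: PyTorch uses `'cuda:0'` where Paddle uses `'gpu:0'`. A hypothetical helper for the conversion (`to_paddle_device` is not part of either library):

```python
import paddle

def to_paddle_device(device):
    # Map a PyTorch-style device string to Paddle's naming convention
    return str(device).replace('cuda', 'gpu')

paddle.device.set_device(to_paddle_device('cuda:0'))  # same as 'gpu:0'
```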