Enable sharing pytorch modules #1695

Merged
merged 1 commit on Dec 26, 2023
13 changes: 13 additions & 0 deletions python/core.cc
@@ -224,6 +224,19 @@ void bind_core(py::module& mod) {
            return py::cast(self->GetMember(key));
          },
          doc::ObjectMeta_get_member)
      .def(
          "member" /* alias for get_member() */,
          [](ObjectMeta* self, std::string const& key) -> py::object {
            auto const& tree = self->MetaData();
            auto iter = tree.find(key);
            if (iter == tree.end()) {
              return py::none();
            }
            VINEYARD_ASSERT(iter->is_object() && !iter->empty(),
                            "The value is not a member, but a meta");
            return py::cast(self->GetMember(key));
          },
          doc::ObjectMeta_get_member)
      .def("get_buffer",
           [](ObjectMeta* self, const ObjectID key) -> py::memoryview {
             std::shared_ptr<Buffer> buffer;
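The new `member` binding is an alias for `get_member()` that returns `None` when the key is absent. A minimal usage sketch from Python, assuming `object_id` refers to an existing object and `partitions_-0` is a hypothetical member key:

```python
meta = client.get_meta(object_id)
# alias for meta.get_member(...); returns None when the key does not exist
member = meta.member('partitions_-0')
```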
38 changes: 38 additions & 0 deletions python/vineyard/contrib/ml/README.md
@@ -10,6 +10,8 @@ and inference tasks in these frameworks.
Examples
--------

### Datasets

The following example shows how a `DataFrame` in vineyard can be used as the input `Dataset` for PyTorch:

@@ -49,6 +51,42 @@ for data, label in pipe:
    pass
```

### PyTorch Modules

The following example shows how to use vineyard to share PyTorch modules between processes:

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import vineyard

# connect to vineyard, see also: https://v6d.io/notes/getting-started.html
client = vineyard.connect(os.environ['VINEYARD_IPC_SOCKET'])

# define a dummy model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()

# put the model into vineyard
from vineyard.contrib.ml.torch import torch_context

with torch_context():
    object_id = client.put(model)

# get the module's state dict from vineyard and load it into a new model
model = Model()
with torch_context():
    state_dict = client.get(object_id)
model.load_state_dict(state_dict, assign=True)
```
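Since the state dict is stored as ordinary vineyard objects, it can also be fetched from another process, including through the RPC client. A minimal sketch, assuming a vineyard RPC endpoint is reachable at the address below:

```python
# assumption: a vineyard RPC endpoint is listening on localhost:9600
rpc_client = vineyard.connect('localhost', 9600)

with torch_context():
    state_dict = rpc_client.get(object_id)

model = Model()
model.load_state_dict(state_dict, assign=True)
```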

Reference and Implementation
----------------------------

45 changes: 45 additions & 0 deletions python/vineyard/contrib/ml/tests/test_torch.py
@@ -16,6 +16,11 @@
# limitations under the License.
#

import copy
import itertools
from typing import Any
from typing import Dict

import numpy as np
import pandas as pd
import pyarrow as pa
@@ -30,6 +35,8 @@
from vineyard.data.dataframe import NDArrayArray

torch = lazy_import.lazy_module("torch")
nn = lazy_import.lazy_module("torch.nn")
F = lazy_import.lazy_module("torch.nn.functional")
torchdata = lazy_import.lazy_module("torchdata")


@@ -130,3 +137,41 @@ def test_torch_dataset_table(vineyard_client):
    assert torch.isclose(
        value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
    ).all()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


def assert_torch_module_equal(model1, model2):
    assert isinstance(model1, nn.Module)
    assert isinstance(model2, nn.Module)
    assert len(list(model1.parameters())) == len(list(model2.parameters()))
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert torch.allclose(p1, p2), f'{p1} != {p2}'


@pytest_cases.parametrize(
    "vineyard_client,model",
    itertools.product(
        [vineyard_client, vineyard_rpc_client],
        [nn.Linear(5, 2), nn.Conv2d(1, 20, 5), Model()],
    ),
)
def test_torch_module(vineyard_client, model):
    object_id = vineyard_client.put(model)
    value: Dict[str, Any] = vineyard_client.get(object_id)

    result = copy.deepcopy(model)
    result.to(torch.device('meta'))
    result.load_state_dict(value, assign=True)

    # check the module's equality
    assert_torch_module_equal(model, result)
102 changes: 84 additions & 18 deletions python/vineyard/contrib/ml/torch.py
@@ -30,10 +30,10 @@

from vineyard._C import ObjectMeta
from vineyard.core import context
from vineyard.data.utils import from_json
from vineyard.data.utils import to_json

torch = lazy_import.lazy_module("torch")
torchdata = lazy_import.lazy_module("torchdata")


class WholeBatchSampler(torch.utils.data.Sampler[List[int]]):
@@ -137,25 +137,9 @@ def torch_global_dataframe_resolver(obj, resolver, **_kw):
    return torch.utils.data.ConcatDataset(data)


def register_torch_types(builder_ctx, resolver_ctx):
    if builder_ctx is not None:
        builder_ctx.register(torch.Tensor, torch_tensor_builder)
        builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder)

    if resolver_ctx is not None:
        resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver)
        resolver_ctx.register('vineyard::DataFrame', torch_dataset_resolver)
        resolver_ctx.register('vineyard::RecordBatch', torch_dataset_resolver)
        resolver_ctx.register('vineyard::Table', torch_dataset_resolver)
        resolver_ctx.register('vineyard::GlobalTensor', torch_global_tensor_resolver)
        resolver_ctx.register(
            'vineyard::GlobalDataFrame', torch_global_dataframe_resolver
        )


def datapipe(
    dataset: torch.utils.data.Dataset,
) -> torchdata.datapipes.iter.IterableWrapper:
): # -> "torchdata.datapipes.iter.IterableWrapper":
    '''Convert a torch.utils.data.Dataset to a torchdata.datapipes.iter.IterableWrapper.

    e.g.,
@@ -182,9 +166,91 @@ Returns:
    Returns:
        A torchdata.datapipes.iter.IterableWrapper.
    '''
    import torchdata

    return torchdata.datapipes.iter.IterableWrapper(dataset)


def torch_module_builder(client, value, builder, **kw):
    def go(state_dict, key_prefix, tensors):
        # recursively walk the state dict: each tensor is built into vineyard and
        # recorded under its flattened key (e.g., 'tensor.conv1.weight')
        if isinstance(state_dict, torch.Tensor):
            r = builder.run(client, state_dict, **kw)
            tensors[key_prefix] = r
            if isinstance(r, ObjectMeta):
                r = r.id
            return r
        elif isinstance(state_dict, dict):
            keys = list(state_dict.keys())
            for key in keys:
                state_dict[key] = go(state_dict[key], f'{key_prefix}.{key}', tensors)
            return state_dict
        elif isinstance(state_dict, (tuple, list)):
            return [
                go(element, f'{key_prefix}.{i}', tensors)
                for i, element in enumerate(state_dict)
            ]
        else:
            return state_dict

    if isinstance(value, torch.nn.Module):
        value = value.state_dict()

    tensors = dict()
    value = go(value, 'tensor', tensors)

    # the JSON-encoded structure keeps object ids in place of tensors; the tensors
    # themselves are attached as members of the metadata
    meta = ObjectMeta()
    meta['typename'] = 'vineyard::torch::Module'
    meta['state_dict'] = to_json(value)
    for key, tensor in tensors.items():
        meta.add_member(key, tensor)
    return client.create_metadata(meta)
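For reference, a rough sketch of what this builder produces for a plain `nn.Linear(5, 2)`, inspected through the public client API (the socket path comes from the environment; ids will differ):

```python
import os

import torch.nn as nn
import vineyard
from vineyard.contrib.ml.torch import torch_context

client = vineyard.connect(os.environ['VINEYARD_IPC_SOCKET'])
with torch_context():
    object_id = client.put(nn.Linear(5, 2))

meta = client.get_meta(object_id)
print(meta['typename'])           # 'vineyard::torch::Module'
print(meta['state_dict'])         # JSON with object ids in place of 'weight' and 'bias'
meta.get_member('tensor.weight')  # the weight tensor, stored as a vineyard::Tensor member
```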


def torch_module_resolver(obj, resolver, **kw):
    def go(state_dict, key_prefix, tensors):
        # recursively put the resolved tensors back into their original positions
        # in the state dict, matched by their flattened keys
        if key_prefix in tensors:
            return tensors[key_prefix]
        elif isinstance(state_dict, dict):
            keys = list(state_dict.keys())
            for key in keys:
                state_dict[key] = go(state_dict[key], f'{key_prefix}.{key}', tensors)
            return state_dict
        elif isinstance(state_dict, (tuple, list)):
            return [
                go(element, f'{key_prefix}.{i}', tensors)
                for i, element in enumerate(state_dict)
            ]
        else:
            return state_dict

    meta = obj.meta
    state_dict = from_json(meta['state_dict'])
    tensors = dict()
    for key, value in meta.items():
        if key.startswith('tensor.'):
            tensors[key] = resolver.run(value, **kw)
    state_dict = go(state_dict, 'tensor', tensors)
    return state_dict


def register_torch_types(builder_ctx, resolver_ctx):
    if builder_ctx is not None:
        builder_ctx.register(torch.Tensor, torch_tensor_builder)
        builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder)
        builder_ctx.register(torch.nn.Module, torch_module_builder)

    if resolver_ctx is not None:
        resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver)
        resolver_ctx.register('vineyard::DataFrame', torch_dataset_resolver)
        resolver_ctx.register('vineyard::RecordBatch', torch_dataset_resolver)
        resolver_ctx.register('vineyard::Table', torch_dataset_resolver)
        resolver_ctx.register('vineyard::GlobalTensor', torch_global_tensor_resolver)
        resolver_ctx.register(
            'vineyard::GlobalDataFrame', torch_global_dataframe_resolver
        )
        resolver_ctx.register('vineyard::torch::Module', torch_module_resolver)


@contextlib.contextmanager
def torch_context():
    with context() as (builder_ctx, resolver_ctx):
3 changes: 3 additions & 0 deletions python/vineyard/core/resolver.py
@@ -45,6 +45,9 @@ def __init__(self, parent_context: Optional["ResolverContext"] = None):
    def __str__(self) -> str:
        return str(self._factory)

    def __repr__(self) -> str:
        return repr(self._factory)

    @property
    def parent_context(self) -> "ResolverContext":
        return self._parent_context
2 changes: 2 additions & 0 deletions python/vineyard/data/utils.py
@@ -184,6 +184,8 @@ def build_numpy_buffer(client, array):
def default_json_encoder(value):
    if isinstance(value, (np.integer, np.floating)):
        return value.item()
    if isinstance(value, ObjectID):
        return int(value)
    raise TypeError

