Skip to content

Commit b2ff1aa

Browse files
pietern authored and facebook-github-bot committed
Add sequence number to gloo::transport::tcp::Address
Summary: Pull Request resolved: pytorch#239
Differential Revision: D18909096
Test Plan: Imported from OSS
Pulled By: pietern
fbshipit-source-id: f0e22f24d7cdd0120f8265268a7ad0e20bcd1109
1 parent b78aa2c commit b2ff1aa

File tree

5 files changed

+73
-35
lines changed

5 files changed

+73
-35
lines changed

gloo/rendezvous/context.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
#include "gloo/common/error.h"
1515
#include "gloo/context.h"
1616
#include "gloo/rendezvous/store.h"
17+
#include "gloo/transport/address.h"
1718
#include "gloo/transport/device.h"
1819

1920
namespace gloo {
@@ -38,8 +39,8 @@ class Context : public ::gloo::Context {
3839

3940
class ContextFactory {
4041
public:
41-
// Assume a pair's address is no bigger than 128 bytes
42-
static constexpr auto kMaxAddressSize = 128;
42+
static constexpr auto kMaxAddressSize =
43+
::gloo::transport::Address::kMaxByteSize;
4344

4445
explicit ContextFactory(std::shared_ptr<::gloo::Context> backingContext);
4546

gloo/transport/address.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,9 @@ namespace transport {
1616

1717
class Address {
1818
public:
19+
// Upper bound for an address' byte representation.
20+
static constexpr auto kMaxByteSize = 192;
21+
1922
virtual ~Address() = 0;
2023

2124
virtual std::string str() const = 0;

gloo/transport/tcp/address.cc

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -17,47 +17,53 @@ namespace gloo {
1717
namespace transport {
1818
namespace tcp {
1919

20-
Address::Address(const struct sockaddr_storage& ss) {
21-
ss_ = ss;
20+
Address::Address(struct sockaddr_storage ss, sequence_number_t seq) {
21+
impl_.ss = std::move(ss);
22+
impl_.seq = seq;
2223
}
2324

2425
Address::Address(const struct sockaddr* addr, size_t addrlen) {
25-
memcpy(&ss_, addr, addrlen);
26+
memcpy(&impl_.ss, addr, addrlen);
2627
}
2728

2829
Address::Address(const std::vector<char>& bytes) {
29-
GLOO_ENFORCE_EQ(sizeof(ss_), bytes.size());
30-
memcpy(&ss_, bytes.data(), sizeof(ss_));
30+
GLOO_ENFORCE_EQ(sizeof(impl_), bytes.size());
31+
memcpy(&impl_, bytes.data(), sizeof(impl_));
3132
}
3233

3334
std::vector<char> Address::bytes() const {
34-
std::vector<char> bytes(sizeof(ss_));
35-
memcpy(bytes.data(), &ss_, sizeof(ss_));
35+
std::vector<char> bytes(sizeof(impl_));
36+
memcpy(bytes.data(), &impl_, sizeof(impl_));
3637
return bytes;
3738
}
3839

3940
std::string Address::str() const {
40-
char str[INET6_ADDRSTRLEN + 8];
41+
char str[INET6_ADDRSTRLEN + 128];
4142
int port = 0;
4243

4344
str[0] = '[';
44-
if (ss_.ss_family == AF_INET) {
45-
struct sockaddr_in* in = (struct sockaddr_in*)&ss_;
45+
if (impl_.ss.ss_family == AF_INET) {
46+
struct sockaddr_in* in = (struct sockaddr_in*)&impl_.ss;
4647
inet_ntop(AF_INET, &in->sin_addr, str + 1, sizeof(str) - 1);
4748
port = in->sin_port;
48-
} else if (ss_.ss_family == AF_INET6) {
49-
struct sockaddr_in6* in6 = (struct sockaddr_in6*)&ss_;
49+
} else if (impl_.ss.ss_family == AF_INET6) {
50+
struct sockaddr_in6* in6 = (struct sockaddr_in6*)&impl_.ss;
5051
inet_ntop(AF_INET6, &in6->sin6_addr, str + 1, sizeof(str) - 1);
5152
port = in6->sin6_port;
5253
} else {
5354
snprintf(str + 1, sizeof(str) - 1, "none");
5455
}
5556

56-
auto len = strlen(str);
57+
size_t len = strlen(str);
5758
if (port > 0) {
58-
snprintf(str + len, sizeof(str) - len, "]:%d", port);
59+
len += snprintf(str + len, sizeof(str) - len, "]:%d", port);
5960
} else {
60-
snprintf(str + len, sizeof(str) - len, "]");
61+
len += snprintf(str + len, sizeof(str) - len, "]");
62+
}
63+
64+
// Append sequence number if one is set.
65+
if (impl_.seq != kSequenceNumberUnset) {
66+
len += snprintf(str + len, sizeof(str) - len, "$%ld", impl_.seq);
6167
}
6268

6369
return str;

gloo/transport/tcp/address.h

Lines changed: 34 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,38 +8,63 @@
88

99
#pragma once
1010

11-
#include <string>
12-
1311
#include <sys/socket.h>
12+
#include <unistd.h>
1413

1514
#include "gloo/transport/address.h"
1615

1716
namespace gloo {
1817
namespace transport {
1918
namespace tcp {
2019

21-
// Forward declaration
22-
class Pair;
20+
using sequence_number_t = ssize_t;
2321

2422
class Address : public ::gloo::transport::Address {
2523
public:
24+
static constexpr sequence_number_t kSequenceNumberUnset = -1;
25+
2626
Address() {}
27-
explicit Address(const struct sockaddr_storage&);
27+
28+
explicit Address(struct sockaddr_storage ss, sequence_number_t seq = -1);
29+
2830
explicit Address(const struct sockaddr* addr, size_t addrlen);
31+
2932
explicit Address(const std::vector<char>&);
30-
virtual ~Address() {}
3133

3234
virtual std::vector<char> bytes() const override;
35+
3336
virtual std::string str() const override;
3437

38+
const struct sockaddr_storage& getSockaddr() const {
39+
return impl_.ss;
40+
}
41+
42+
sequence_number_t getSeq() const {
43+
return impl_.seq;
44+
}
45+
3546
static Address fromSockName(int fd);
47+
3648
static Address fromPeerName(int fd);
3749

3850
protected:
39-
struct sockaddr_storage ss_;
51+
// Encapsulate fields such that it is trivially copyable. This class
52+
// is not trivially copyable itself.
53+
struct Impl {
54+
// IP address of the listening socket.
55+
struct sockaddr_storage ss;
56+
57+
// Sequence number of this address.
58+
// If this is equal to -1, the address is assumed to
59+
// represent the listening socket of a device. The sequence number
60+
// must be set before it can be used by a pair.
61+
sequence_number_t seq{kSequenceNumberUnset};
62+
};
63+
64+
static_assert(std::is_trivially_copyable<Impl>::value, "!");
65+
static_assert(sizeof(Impl) <= kMaxByteSize, "!");
4066

41-
// Pair can access ss_ directly
42-
friend class Pair;
67+
Impl impl_;
4368
};
4469

4570
} // namespace tcp

gloo/transport/tcp/pair.cc

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -202,22 +202,25 @@ void Pair::connect(const Address& peer) {
202202

203203
peer_ = peer;
204204

205+
const auto& selfAddr = self_.getSockaddr();
206+
const auto& peerAddr = peer_.getSockaddr();
207+
205208
// Addresses have to have same family
206-
if (self_.ss_.ss_family != peer_.ss_.ss_family) {
209+
if (selfAddr.ss_family != peerAddr.ss_family) {
207210
GLOO_THROW_INVALID_OPERATION_EXCEPTION("address family mismatch");
208211
}
209212

210-
if (self_.ss_.ss_family == AF_INET) {
211-
struct sockaddr_in* sa = (struct sockaddr_in*)&self_.ss_;
212-
struct sockaddr_in* sb = (struct sockaddr_in*)&peer_.ss_;
213+
if (selfAddr.ss_family == AF_INET) {
214+
struct sockaddr_in* sa = (struct sockaddr_in*)&selfAddr;
215+
struct sockaddr_in* sb = (struct sockaddr_in*)&peerAddr;
213216
addrlen = sizeof(struct sockaddr_in);
214217
rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr));
215218
if (rv == 0) {
216219
rv = sa->sin_port - sb->sin_port;
217220
}
218-
} else if (peer_.ss_.ss_family == AF_INET6) {
219-
struct sockaddr_in6* sa = (struct sockaddr_in6*)&self_.ss_;
220-
struct sockaddr_in6* sb = (struct sockaddr_in6*)&peer_.ss_;
221+
} else if (peerAddr.ss_family == AF_INET6) {
222+
struct sockaddr_in6* sa = (struct sockaddr_in6*)&selfAddr;
223+
struct sockaddr_in6* sb = (struct sockaddr_in6*)&peerAddr;
221224
addrlen = sizeof(struct sockaddr_in6);
222225
rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr));
223226
if (rv == 0) {
@@ -243,7 +246,7 @@ void Pair::connect(const Address& peer) {
243246
::close(fd_);
244247

245248
// Create new socket to connect to peer.
246-
fd_ = socket(peer_.ss_.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
249+
fd_ = socket(peerAddr.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
247250
if (fd_ == -1) {
248251
signalAndThrowException(GLOO_ERROR_MSG("socket: ", strerror(errno)));
249252
}
@@ -258,7 +261,7 @@ void Pair::connect(const Address& peer) {
258261
}
259262

260263
// Connect to peer
261-
rv = ::connect(fd_, (struct sockaddr*)&peer_.ss_, addrlen);
264+
rv = ::connect(fd_, (struct sockaddr*)&peerAddr, addrlen);
262265
if (rv == -1 && errno != EINPROGRESS) {
263266
::close(fd_);
264267
fd_ = FD_INVALID;

0 commit comments

Comments
 (0)