Skip to content

Commit

Permalink
pubsub/rabbitpubsub: add query string to set the routing key from met…
Browse files Browse the repository at this point in the history
…adata (#3433)
  • Loading branch information
peczenyj authored May 31, 2024
1 parent e677ded commit 0866b65
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 29 deletions.
4 changes: 3 additions & 1 deletion pubsub/kafkapubsub/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ const Scheme = "kafka"
// URLOpener opens Kafka URLs like "kafka://mytopic" for topics and
// "kafka://group?topic=mytopic" for subscriptions.
//
// For topics, the URL's host+path is used as the topic name.
// For topics, the URL's host+path is used as the topic name,
// and the "key_name" query parameter is used to extract the routing key
// from metadata.
//
// For subscriptions, the URL's host+path is used as the group name,
// and the "topic" query parameter(s) are used as the set of topics to
Expand Down
8 changes: 2 additions & 6 deletions pubsub/rabbitpubsub/amqp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ const (
// response. We always want to wait.
wait = false

// Always use the empty routing key. This driver expects to be used with topic
// exchanges, which disregard the routing key.
routingKey = ""

// If the message can't be enqueued, return it to the sender rather than silently
// dropping it.
mandatory = true
Expand All @@ -49,7 +45,7 @@ type amqpConnection interface {

// See https://pkg.go.dev/github.com/rabbitmq/amqp091-go#Channel for the documentation of these methods.
type amqpChannel interface {
Publish(exchange string, msg amqp.Publishing) error
Publish(exchange, routingKey string, msg amqp.Publishing) error
Consume(queue, consumer string) (<-chan amqp.Delivery, error)
Ack(tag uint64) error
Nack(tag uint64) error
Expand Down Expand Up @@ -93,7 +89,7 @@ type channel struct {
ch *amqp.Channel
}

func (ch *channel) Publish(exchange string, msg amqp.Publishing) error {
func (ch *channel) Publish(exchange, routingKey string, msg amqp.Publishing) error {
return ch.ch.Publish(exchange, routingKey, mandatory, immediate, msg)
}

Expand Down
9 changes: 5 additions & 4 deletions pubsub/rabbitpubsub/fake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (ch *fakeChannel) QueueDeclareAndBind(queueName, exchangeName string) error
return nil
}

func (ch *fakeChannel) Publish(exchangeName string, pub amqp.Publishing) error {
func (ch *fakeChannel) Publish(exchangeName, routingKey string, pub amqp.Publishing) error {
if ch.isClosed() {
return amqp.ErrClosed
}
Expand All @@ -168,9 +168,10 @@ func (ch *fakeChannel) Publish(exchangeName string, pub amqp.Publishing) error {
// The message is unroutable. Send a Return to all channels registered with
// NotifyReturn.
ret := amqp.Return{
Exchange: exchangeName,
ReplyCode: amqp.NoRoute,
ReplyText: "NO_ROUTE: no queues bound to exchange",
Exchange: exchangeName,
ReplyCode: amqp.NoRoute,
ReplyText: "NO_ROUTE: no queues bound to exchange",
RoutingKey: routingKey,
}
for _, c := range ch.returnChans {
select {
Expand Down
69 changes: 54 additions & 15 deletions pubsub/rabbitpubsub/rabbit.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,22 @@ type URLOpener struct {

// OpenTopicURL opens a pubsub.Topic based on u.
func (o *URLOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic, error) {
for param := range u.Query() {
return nil, fmt.Errorf("open topic %v: invalid query parameter %q", u, param)
opts := o.TopicOptions
for param, value := range u.Query() {
switch param {
case "key_name":
if len(value) != 1 || len(value[0]) == 0 {
return nil, fmt.Errorf("open topic %v: invalid query parameter %q", u, param)
}

opts.KeyName = value[0]
default:
return nil, fmt.Errorf("open topic %v: invalid query parameter %q", u, param)
}
}

exchangeName := path.Join(u.Host, u.Path)
return OpenTopic(o.Connection, exchangeName, &o.TopicOptions), nil
return OpenTopic(o.Connection, exchangeName, &opts), nil
}

// OpenSubscriptionURL opens a pubsub.Subscription based on u.
Expand Down Expand Up @@ -147,6 +158,7 @@ func (o *URLOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsu
type topic struct {
exchange string // the AMQP exchange
conn amqpConnection
opts *TopicOptions

mu sync.Mutex
ch amqpChannel // AMQP channel used for all communication.
Expand All @@ -157,11 +169,22 @@ type topic struct {

// TopicOptions sets options for constructing a *pubsub.Topic backed by
// RabbitMQ.
type TopicOptions struct{}
type TopicOptions struct {
// KeyName optionally sets the Message.Metadata key to use as the optional
// RabbitMQ message key. If set, and if a matching Message.Metadata key is found,
// the value for that key will be used as the routing key when sending to
// RabbitMQ, instead of being added to the message headers.
KeyName string
}

// SubscriptionOptions sets options for constructing a *pubsub.Subscription
// backed by RabbitMQ.
type SubscriptionOptions struct {
// KeyName optionally sets the Message.Metadata key in which to store the
// RabbitMQ message key. If set, and if the RabbitMQ message key is non-empty,
// the key value will be stored in Message.Metadata under KeyName.
KeyName string

// Qos property prefetch count. Optional.
PrefetchCount *int
}
Expand All @@ -181,13 +204,18 @@ type SubscriptionOptions struct {
// The documentation of the amqp package recommends using separate connections for
// publishing and subscribing.
func OpenTopic(conn *amqp.Connection, name string, opts *TopicOptions) *pubsub.Topic {
return pubsub.NewTopic(newTopic(&connection{conn}, name), nil)
return pubsub.NewTopic(newTopic(&connection{conn}, name, opts), nil)
}

func newTopic(conn amqpConnection, name string) *topic {
func newTopic(conn amqpConnection, name string, opts *TopicOptions) *topic {
if opts == nil {
opts = &TopicOptions{}
}

return &topic{
conn: conn,
exchange: name,
opts: opts,
}
}

Expand Down Expand Up @@ -271,7 +299,7 @@ func (t *topic) SendBatch(ctx context.Context, ms []*driver.Message) error {

var perr error
for _, m := range ms {
pub := toPublishing(m)
routingKey, pub := toRoutingKeyAndAMQPPublishing(m, t.opts)
if m.BeforeSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(**amqp.Publishing); ok {
Expand All @@ -284,7 +312,7 @@ func (t *topic) SendBatch(ctx context.Context, ms []*driver.Message) error {
return err
}
}
if perr = ch.Publish(t.exchange, pub); perr != nil {
if perr = ch.Publish(t.exchange, routingKey, pub); perr != nil {
cancel()
break
}
Expand Down Expand Up @@ -410,16 +438,23 @@ func closeErr(closec <-chan *amqp.Error) error {
}
}

// toPublishing converts a driver.Message to an amqp.Publishing.
func toPublishing(m *driver.Message) amqp.Publishing {
// toRoutingKeyAndAMQPPublishing converts a driver.Message to a pair routingKey + amqp.Publishing.
func toRoutingKeyAndAMQPPublishing(m *driver.Message, opts *TopicOptions) (routingKey string, msg amqp.Publishing) {
h := amqp.Table{}
for k, v := range m.Metadata {
h[k] = v
if opts.KeyName == k {
routingKey = v
} else {
h[k] = v
}
}
return amqp.Publishing{

msg = amqp.Publishing{
Headers: h,
Body: m.Body,
}

return routingKey, msg
}

// IsRetryable implements driver.Topic.IsRetryable.
Expand Down Expand Up @@ -665,7 +700,7 @@ func (s *subscription) ReceiveBatch(ctx context.Context, maxMessages int) ([]*dr
// error.
return nil, errors.New("rabbitpubsub: delivery channel closed unexpectedly")
}
ms = append(ms, toMessage(d))
ms = append(ms, toDriverMessage(d, s.opts))
if len(ms) >= maxMessages {
return ms, nil
}
Expand All @@ -679,14 +714,18 @@ func (s *subscription) ReceiveBatch(ctx context.Context, maxMessages int) ([]*dr
}
}

// toMessage converts an amqp.Delivery (a received message) to a driver.Message.
func toMessage(d amqp.Delivery) *driver.Message {
// toDriverMessage converts an amqp.Delivery (a received message) to a driver.Message.
func toDriverMessage(d amqp.Delivery, opts *SubscriptionOptions) *driver.Message {
// Delivery.Headers is a map[string]interface{}, so we have to
// convert each value to a string.
md := map[string]string{}
for k, v := range d.Headers {
md[k] = fmt.Sprint(v)
}
// Add a metadata entry for the message routing key if appropriate.
if d.RoutingKey != "" && opts.KeyName != "" {
md[opts.KeyName] = d.RoutingKey
}
loggableID := d.MessageId
if loggableID == "" {
loggableID = d.CorrelationId
Expand Down
8 changes: 5 additions & 3 deletions pubsub/rabbitpubsub/rabbit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ func (h *harness) CreateTopic(_ context.Context, testName string) (dt driver.Top
}
ch.ExchangeDelete(exchange)
}
return newTopic(h.conn, exchange), cleanup, nil
return newTopic(h.conn, exchange, nil), cleanup, nil
}

func (h *harness) MakeNonexistentTopic(context.Context) (driver.Topic, error) {
return newTopic(h.conn, "nonexistent-topic"), nil
return newTopic(h.conn, "nonexistent-topic", nil), nil
}

func (h *harness) CreateSubscription(_ context.Context, dt driver.Topic, testName string) (ds driver.Subscription, cleanup func(), err error) {
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestUnroutable(t *testing.T) {
if err := declareExchange(conn, "u"); err != nil {
t.Fatal(err)
}
topic := newTopic(conn, "u")
topic := newTopic(conn, "u", nil)
msgs := []*driver.Message{
{Body: []byte("")},
{Body: []byte("")},
Expand Down Expand Up @@ -394,7 +394,9 @@ func TestOpenTopicFromURL(t *testing.T) {
WantErr bool
}{
{"valid url", "rabbit://%s", false},
{"valid url with key name parameter", "rabbit://%s?key_name=foo", false},
{"invalid url with parameters", "rabbit://%s?param=value", true},
{"invalid url with key name parameter", "rabbit://%s?key_name=", true},
}

for _, test := range tests {
Expand Down

0 comments on commit 0866b65

Please sign in to comment.