diff --git a/main.go b/main.go index ee55f7e..3dade73 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "database/sql" "encoding/json" "fmt" @@ -57,20 +58,46 @@ func init() { topic = os.Getenv("KAFKA_TOPIC") } -func consumeFromKafka(brokerList []string, topic string, db *sql.DB) error { - consumer, err := sarama.NewConsumer(brokerList, nil) +func consumeFromKafka(brokerList []string, topic string, groupID string, db *sql.DB) error { + config := sarama.NewConfig() + config.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategyRoundRobin + config.Version = sarama.V2_1_0_0 // ensure the version is compatible with your Kafka + + client, err := sarama.NewClient(brokerList, config) if err != nil { - return err + return fmt.Errorf("error creating kafka client: %v", err) } - defer consumer.Close() + defer client.Close() - partitionConsumer, err := consumer.ConsumePartition(topic, 0, sarama.OffsetNewest) + consumerGroup, err := sarama.NewConsumerGroupFromClient(groupID, client) if err != nil { - return err + return fmt.Errorf("error creating consumer group: %v", err) + } + defer consumerGroup.Close() + + ctx := context.Background() + for { + if err := consumerGroup.Consume(ctx, []string{topic}, &consumerGroupHandler{db: db}); err != nil { + log.Printf("Error from consumer: %v", err) + return err + } } - defer partitionConsumer.Close() +} + +type consumerGroupHandler struct { + db *sql.DB +} - for message := range partitionConsumer.Messages() { +func (h *consumerGroupHandler) Setup(_ sarama.ConsumerGroupSession) error { + return nil +} + +func (h *consumerGroupHandler) Cleanup(_ sarama.ConsumerGroupSession) error { + return nil +} + +func (h *consumerGroupHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { + for message := range claim.Messages() { var cve CVERecord err := json.Unmarshal(message.Value, &cve) if err != nil { @@ -78,12 +105,12 @@ func consumeFromKafka(brokerList []string, topic string, db *sql.DB) error { continue } - err = storeInDatabase(db, cve) + err = storeInDatabase(h.db, cve) if err != nil { log.Printf("Failed to store CVE record in database: %v", err) } + sess.MarkMessage(message, "") } - return nil } @@ -163,8 +190,9 @@ func main() { } defer db.Close() + groupID := "cve-consumer-group" // Consistent group ID across all instances go func() { - err = consumeFromKafka(brokerList, topic, db) + err = consumeFromKafka(brokerList, topic, groupID, db) if err != nil { log.Printf("Failed to consume from Kafka: %v", err) }