Skip to content

Commit 9171cb6

Browse files
First round of refactor
1 parent 8c8175b commit 9171cb6

File tree

11 files changed

+652
-46
lines changed

11 files changed

+652
-46
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package assertions
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/codecrafters-io/kafka-tester/protocol/kafkaapi"
7+
"github.com/codecrafters-io/tester-utils/logger"
8+
)
9+
10+
var apiKeyNames = map[int16]string{
11+
1: "FETCH",
12+
18: "API_VERSIONS",
13+
75: "DESCRIBE_TOPIC_PARTITIONS",
14+
}
15+
16+
var errorCodes = map[int]string{
17+
0: "NO_ERROR",
18+
3: "UNKNOWN_TOPIC_OR_PARTITION",
19+
35: "UNSUPPORTED_VERSION",
20+
100: "UNKNOWN_TOPIC_ID",
21+
}
22+
23+
type ApiVersionsResponseAssertion struct {
24+
ActualValue kafkaapi.ApiVersionsResponse
25+
ExpectedValue kafkaapi.ApiVersionsResponse
26+
}
27+
28+
func NewApiVersionsResponseAssertion(actualValue kafkaapi.ApiVersionsResponse, expectedValue kafkaapi.ApiVersionsResponse) *ApiVersionsResponseAssertion {
29+
return &ApiVersionsResponseAssertion{
30+
ActualValue: actualValue,
31+
ExpectedValue: expectedValue,
32+
}
33+
}
34+
35+
func (a *ApiVersionsResponseAssertion) assertBody(logger *logger.Logger) error {
36+
expectedErrorCodeName, ok := errorCodes[int(a.ExpectedValue.Body.ErrorCode)]
37+
if !ok {
38+
panic(fmt.Sprintf("CodeCrafters Internal Error: Expected %d to be in errorCodes map", a.ExpectedValue.Body.ErrorCode))
39+
}
40+
if a.ActualValue.Body.ErrorCode != a.ExpectedValue.Body.ErrorCode {
41+
return fmt.Errorf("Expected ErrorCode to be %d (%s), got %d", a.ExpectedValue.Body.ErrorCode, expectedErrorCodeName, a.ActualValue.Body.ErrorCode)
42+
}
43+
logger.Successf("✓ ErrorCode: %d (%s)", a.ActualValue.Body.ErrorCode, expectedErrorCodeName)
44+
45+
if err := a.assertAPIKeysArray(logger); err != nil {
46+
return err
47+
}
48+
49+
return nil
50+
}
51+
52+
func (a *ApiVersionsResponseAssertion) assertAPIKeysArray(logger *logger.Logger) error {
53+
if len(a.ActualValue.Body.ApiKeys) < len(a.ExpectedValue.Body.ApiKeys) {
54+
return fmt.Errorf("Expected API keys array to include atleast %d keys, got %d", len(a.ExpectedValue.Body.ApiKeys), len(a.ActualValue.Body.ApiKeys))
55+
}
56+
logger.Successf("✓ API keys array length: %d", len(a.ActualValue.Body.ApiKeys))
57+
58+
for _, expectedApiVersionKey := range a.ExpectedValue.Body.ApiKeys {
59+
found := false
60+
61+
for _, actualApiVersionKey := range a.ActualValue.Body.ApiKeys {
62+
if actualApiVersionKey.ApiKey == expectedApiVersionKey.ApiKey {
63+
found = true
64+
if actualApiVersionKey.MinVersion > expectedApiVersionKey.MaxVersion {
65+
return fmt.Errorf("Expected min version %v to be < max version %v for %s", actualApiVersionKey.MinVersion, expectedApiVersionKey.MaxVersion, apiKeyNames[expectedApiVersionKey.ApiKey])
66+
}
67+
68+
// anything above or equal to expected minVersion is fine
69+
if actualApiVersionKey.MinVersion < expectedApiVersionKey.MinVersion {
70+
return fmt.Errorf("Expected API version %v to be supported for %s, got %v", expectedApiVersionKey.MinVersion, apiKeyNames[expectedApiVersionKey.ApiKey], actualApiVersionKey.MinVersion)
71+
}
72+
logger.Successf("✓ MinVersion for %s is <= %v & >= %v", apiKeyNames[expectedApiVersionKey.ApiKey], expectedApiVersionKey.MaxVersion, expectedApiVersionKey.MinVersion)
73+
74+
if actualApiVersionKey.MaxVersion < expectedApiVersionKey.MaxVersion {
75+
return fmt.Errorf("Expected API version %v to be supported for %s, got %v", expectedApiVersionKey.MaxVersion, apiKeyNames[expectedApiVersionKey.ApiKey], actualApiVersionKey.MaxVersion)
76+
}
77+
logger.Successf("✓ MaxVersion for %s is >= %v", apiKeyNames[expectedApiVersionKey.ApiKey], expectedApiVersionKey.MaxVersion)
78+
}
79+
}
80+
81+
if !found {
82+
return fmt.Errorf("Expected APIVersionsResponseKey array to include API key %d (%s)", expectedApiVersionKey.ApiKey, apiKeyNames[expectedApiVersionKey.ApiKey])
83+
}
84+
}
85+
86+
return nil
87+
}
88+
89+
func (a *ApiVersionsResponseAssertion) Run(logger *logger.Logger) error {
90+
if err := NewResponseHeaderAssertion(a.ActualValue.Header, a.ExpectedValue.Header).Run(logger); err != nil {
91+
return err
92+
}
93+
94+
if err := a.assertBody(logger); err != nil {
95+
return err
96+
}
97+
98+
return nil
99+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package assertions
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/codecrafters-io/kafka-tester/protocol/kafkaapi/headers"
7+
"github.com/codecrafters-io/tester-utils/logger"
8+
)
9+
10+
type ResponseHeaderAssertion struct {
11+
ActualValue headers.ResponseHeader
12+
ExpectedValue headers.ResponseHeader
13+
}
14+
15+
func NewResponseHeaderAssertion(actualValue headers.ResponseHeader, expectedValue headers.ResponseHeader) *ResponseHeaderAssertion {
16+
return &ResponseHeaderAssertion{
17+
ActualValue: actualValue,
18+
ExpectedValue: expectedValue,
19+
}
20+
}
21+
22+
func (a *ResponseHeaderAssertion) assertCorrelationId(logger *logger.Logger) error {
23+
if a.ActualValue.CorrelationId != a.ExpectedValue.CorrelationId {
24+
return fmt.Errorf("Expected correlation_id to be %d, got %d", a.ExpectedValue.CorrelationId, a.ActualValue.CorrelationId)
25+
}
26+
logger.Successf("✓ correlation_id: %v", a.ActualValue.CorrelationId)
27+
28+
return nil
29+
}
30+
31+
func (a *ResponseHeaderAssertion) Run(logger *logger.Logger) error {
32+
return a.assertCorrelationId(logger)
33+
}

internal/stage_2.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
77
"github.com/codecrafters-io/kafka-tester/protocol/builder"
88
"github.com/codecrafters-io/kafka-tester/protocol/decoder"
9-
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client_legacy"
9+
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client"
1010
"github.com/codecrafters-io/kafka-tester/protocol/kafkaapi"
1111
"github.com/codecrafters-io/kafka-tester/protocol/serializer_legacy"
1212
"github.com/codecrafters-io/kafka-tester/protocol/utils"
@@ -27,7 +27,7 @@ func testHardcodedCorrelationId(stageHarness *test_case_harness.TestCaseHarness)
2727
return err
2828
}
2929

30-
client := kafka_client_legacy.NewClient("localhost:9092")
30+
client := kafka_client.NewClient("localhost:9092")
3131

3232
if err := client.ConnectWithRetries(b, stageLogger); err != nil {
3333
return err

internal/stage_3.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
99
"github.com/codecrafters-io/kafka-tester/protocol/builder"
1010
"github.com/codecrafters-io/kafka-tester/protocol/decoder"
11-
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client_legacy"
11+
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client"
1212
"github.com/codecrafters-io/kafka-tester/protocol/kafkaapi"
1313
"github.com/codecrafters-io/kafka-tester/protocol/serializer_legacy"
1414
"github.com/codecrafters-io/kafka-tester/protocol/utils"
@@ -28,7 +28,7 @@ func testCorrelationId(stageHarness *test_case_harness.TestCaseHarness) error {
2828
return err
2929
}
3030

31-
client := kafka_client_legacy.NewClient("localhost:9092")
31+
client := kafka_client.NewClient("localhost:9092")
3232
if err := client.ConnectWithRetries(b, stageLogger); err != nil {
3333
return err
3434
}

internal/stage_4.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
99
"github.com/codecrafters-io/kafka-tester/protocol/builder"
1010
"github.com/codecrafters-io/kafka-tester/protocol/decoder"
11-
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client_legacy"
11+
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client"
1212
"github.com/codecrafters-io/kafka-tester/protocol/kafkaapi"
1313
"github.com/codecrafters-io/kafka-tester/protocol/serializer_legacy"
1414
"github.com/codecrafters-io/kafka-tester/protocol/utils"
@@ -30,7 +30,7 @@ func testAPIVersionErrorCase(stageHarness *test_case_harness.TestCaseHarness) er
3030
correlationId := getRandomCorrelationId()
3131
apiVersion := getInvalidAPIVersion()
3232

33-
client := kafka_client_legacy.NewClient("localhost:9092")
33+
client := kafka_client.NewClient("localhost:9092")
3434
if err := client.ConnectWithRetries(b, stageLogger); err != nil {
3535
return err
3636
}

internal/stage_5.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package internal
22

33
import (
4-
"github.com/codecrafters-io/kafka-tester/internal/assertions_legacy"
4+
"github.com/codecrafters-io/kafka-tester/internal/assertions"
55
"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
6-
"github.com/codecrafters-io/kafka-tester/protocol/builder_legacy"
7-
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client_legacy"
6+
"github.com/codecrafters-io/kafka-tester/protocol/builder"
7+
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client"
88
"github.com/codecrafters-io/kafka-tester/protocol/serializer_legacy"
99
"github.com/codecrafters-io/tester-utils/logger"
1010
"github.com/codecrafters-io/tester-utils/test_case_harness"
@@ -24,13 +24,13 @@ func testAPIVersion(stageHarness *test_case_harness.TestCaseHarness) error {
2424

2525
correlationId := getRandomCorrelationId()
2626

27-
client := kafka_client_legacy.NewClient("localhost:9092")
27+
client := kafka_client.NewClient("localhost:9092")
2828
if err := client.ConnectWithRetries(b, stageLogger); err != nil {
2929
return err
3030
}
3131
defer client.Close()
3232

33-
request := builder_legacy.NewApiVersionsRequestBuilder().
33+
request := builder.NewApiVersionsRequestBuilder().
3434
WithCorrelationId(correlationId).
3535
Build()
3636

@@ -39,17 +39,17 @@ func testAPIVersion(stageHarness *test_case_harness.TestCaseHarness) error {
3939
return err
4040
}
4141

42-
actualResponse := builder_legacy.NewApiVersionsResponseBuilder().BuildEmpty()
42+
actualResponse := builder.NewApiVersionsResponseBuilder().BuildEmpty()
4343
if err := actualResponse.Decode(rawResponse.Payload, stageLogger); err != nil {
4444
return err
4545
}
4646

47-
expectedApiVersionResponse := builder_legacy.NewApiVersionsResponseBuilder().
47+
expectedApiVersionResponse := builder.NewApiVersionsResponseBuilder().
4848
AddApiKeyEntry(18, 0, 4).
4949
WithCorrelationId(correlationId).
5050
Build()
5151

52-
if err = assertions_legacy.NewApiVersionsResponseAssertion(actualResponse, expectedApiVersionResponse).Run(stageLogger); err != nil {
52+
if err = assertions.NewApiVersionsResponseAssertion(actualResponse, expectedApiVersionResponse).Run(stageLogger); err != nil {
5353
return err
5454
}
5555

protocol/decoder/decoder.go

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/codecrafters-io/tester-utils/logger"
1313
)
1414

15+
type tagBuffer struct{}
16+
1517
type Decoder struct {
1618
bytes []byte
1719
offset int
@@ -53,8 +55,22 @@ func (d *Decoder) unindentLog() {
5355
d.indentationLevel = max(d.indentationLevel-1, 0)
5456
}
5557

56-
func (d *Decoder) LogDecodedValue(value string) {
57-
d.logger.Debugf("%s.%s", d.getIndentationString(), value)
58+
func (d *Decoder) LogDecodedValue(variableName string, value any) {
59+
indentation := d.getIndentationString()
60+
if variableName != "" {
61+
switch castedValue := value.(type) {
62+
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
63+
d.logger.Debugf("%s.%s (%d)", indentation, variableName, castedValue)
64+
case string:
65+
d.logger.Debugf("%s.%s (%s)", indentation, variableName, castedValue)
66+
case bool:
67+
d.logger.Debugf("%s.%s (%t)", indentation, variableName, castedValue)
68+
case tagBuffer:
69+
d.logger.Debugf("%s.%s", indentation, "TAG_BUFFER")
70+
default:
71+
d.logger.Debugf("%s%s (%v)", indentation, variableName, castedValue)
72+
}
73+
}
5874
}
5975

6076
func (d *Decoder) MuteLogger() {
@@ -85,9 +101,8 @@ func (d *Decoder) GetInt8(variableName string) (int8, error) {
85101

86102
decodedInteger := int8(d.bytes[d.offset])
87103
d.offset++
88-
if variableName != "" {
89-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
90-
}
104+
105+
d.LogDecodedValue(variableName, decodedInteger)
91106
return decodedInteger, nil
92107
}
93108

@@ -100,9 +115,7 @@ func (d *Decoder) GetInt16(variableName string) (int16, error) {
100115

101116
decodedInteger := int16(binary.BigEndian.Uint16(d.bytes[d.offset:]))
102117
d.offset += 2
103-
if variableName != "" {
104-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
105-
}
118+
d.LogDecodedValue(variableName, decodedInteger)
106119
return decodedInteger, nil
107120
}
108121

@@ -116,9 +129,7 @@ func (d *Decoder) GetInt32(variableName string) (int32, error) {
116129

117130
decodedInteger := int32(binary.BigEndian.Uint32(d.bytes[d.offset:]))
118131
d.offset += 4
119-
if variableName != "" {
120-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
121-
}
132+
d.LogDecodedValue(variableName, decodedInteger)
122133
return decodedInteger, nil
123134
}
124135

@@ -131,9 +142,7 @@ func (d *Decoder) GetInt64(variableName string) (int64, error) {
131142

132143
decodedInteger := int64(binary.BigEndian.Uint64(d.bytes[d.offset:]))
133144
d.offset += 8
134-
if variableName != "" {
135-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
136-
}
145+
d.LogDecodedValue(variableName, decodedInteger)
137146
return decodedInteger, nil
138147
}
139148

@@ -150,9 +159,7 @@ func (d *Decoder) GetUnsignedVarint(variableName string) (uint64, error) {
150159
}
151160

152161
d.offset += n
153-
if variableName != "" {
154-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
155-
}
162+
d.LogDecodedValue(variableName, decodedInteger)
156163
return decodedInteger, nil
157164
}
158165

@@ -168,9 +175,7 @@ func (d *Decoder) GetSignedVarint(variableName string) (int64, error) {
168175
}
169176

170177
d.offset += n
171-
if variableName != "" {
172-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
173-
}
178+
d.LogDecodedValue(variableName, decodedInteger)
174179
return decodedInteger, nil
175180
}
176181

@@ -195,9 +200,7 @@ func (d *Decoder) GetBool(variableName string) (bool, error) {
195200
}
196201

197202
d.offset++
198-
if variableName != "" {
199-
d.LogDecodedValue(fmt.Sprintf("%s (%t)", variableName, decodedBool))
200-
}
203+
d.LogDecodedValue(variableName, decodedBool)
201204
return decodedBool, nil
202205
}
203206

@@ -227,9 +230,7 @@ func (d *Decoder) GetArrayLength(variableName string) (int, error) {
227230
return -1, errors.NewPacketDecodingError(fmt.Sprintf("Invalid array length: %d", arrayLength)).AddContexts("ARRAY_LENGTH", variableName)
228231
}
229232

230-
if variableName != "" {
231-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, arrayLength))
232-
}
233+
d.LogDecodedValue(variableName, arrayLength)
233234
return arrayLength, nil
234235
}
235236

@@ -247,9 +248,7 @@ func (d *Decoder) GetCompactArrayLength(variableName string) (int, error) {
247248
return 0, nil
248249
}
249250

250-
if variableName != "" {
251-
d.LogDecodedValue(fmt.Sprintf("%s (%d)", variableName, decodedInteger))
252-
}
251+
d.LogDecodedValue(variableName, decodedInteger)
253252
return int(decodedInteger) - 1, nil
254253
}
255254

@@ -299,7 +298,7 @@ func (d *Decoder) ConsumeTagBuffer() error {
299298
}
300299
}
301300

302-
d.LogDecodedValue("TAG_BUFFER")
301+
d.LogDecodedValue("", tagBuffer{})
303302
return nil
304303
}
305304

0 commit comments

Comments
 (0)