diff --git a/diskqueue.go b/diskqueue.go
index ee6c22d..f4b5b4f 100644
--- a/diskqueue.go
+++ b/diskqueue.go
@@ -275,6 +275,42 @@ func (d *diskQueue) skipToNextRWFile() error {
 func (d *diskQueue) readOne() ([]byte, error) {
 	var err error
 	var msgSize int32
+
+	// Fix: d.maxBytesPerFileRead may be updated by d.writeOne(), so re-check the
+	// current read position before the next read to avoid an unexpected EOF.
+	if d.readFileNum < d.writeFileNum {
+		if d.maxBytesPerFileRead <= 0 {
+			d.maxBytesPerFileRead = d.maxBytesPerFile
+			readFile := d.fileName(d.readFileNum)
+			stat, err := os.Stat(readFile)
+			if err != nil {
+				d.logf(ERROR, "DISKQUEUE(%s) unable to stat(%s) - %s", d.name, readFile, err)
+			} else {
+				d.maxBytesPerFileRead = stat.Size()
+			}
+		}
+
+		if d.readPos >= d.maxBytesPerFileRead {
+			if d.readFile != nil {
+				if err := d.readFile.Close(); err != nil {
+					d.logf(ERROR, "DISKQUEUE(%s) failed to close(%s) - %s", d.name, d.readFile.Name(), err)
+				}
+				err := os.Remove(d.readFile.Name())
+				if err != nil {
+					d.logf(ERROR, "DISKQUEUE(%s) failed to Remove(%s) - %s", d.name, d.readFile.Name(), err)
+				}
+				d.readFile = nil
+				// sync every time we start reading from a new file
+				err = d.sync()
+				if err != nil {
+					d.logf(ERROR, "DISKQUEUE(%s) failed to sync - %s", d.name, err)
+				}
+			}
+
+			d.readFileNum++
+			d.readPos = 0
+		}
+	}
 
 	if d.readFile == nil {
 		curFileName := d.fileName(d.readFileNum)
diff --git a/diskqueue_test.go b/diskqueue_test.go
index fc72406..89d3dc2 100644
--- a/diskqueue_test.go
+++ b/diskqueue_test.go
@@ -3,8 +3,10 @@ package diskqueue
 import (
 	"bufio"
 	"bytes"
+	"crypto/rand"
 	"fmt"
 	"io/ioutil"
+	"log"
 	"os"
 	"path"
 	"path/filepath"
@@ -696,3 +698,80 @@ func benchmarkDiskQueueGet(size int64, b *testing.B) {
 		<-dq.ReadChan()
 	}
 }
+
+func TestDiskQueue_ReadChan(t *testing.T) {
+	dataDir := "./testdata/dat"
+	err := os.MkdirAll(dataDir, 0o755)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(dataDir)
+
+	var (
+		megabyte int64 = 1 << 20
+		datCount       = 1112
+	)
+
+	dq := New("nsqio_diskqueue", dataDir, 128*megabyte, 0, int32(16*megabyte),
+		32*megabyte, time.Second*5, func(lvl LogLevel, f string, args ...interface{}) {
+			if lvl >= WARN {
+				t.Errorf(f, args...)
+				return
+			}
+			log.Println(lvl, fmt.Sprintf(f, args...))
+		})
+
+	buf := make([]byte, 3231197)
+	n, err := rand.Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != len(buf) {
+		t.Fatal("buf is not full")
+	}
+
+	pushExit := make(chan struct{})
+	go func() {
+		for i := 0; i < datCount; i++ {
+			if err := dq.Put(buf); err != nil {
+				t.Error(err)
+				return
+			}
+		}
+		close(pushExit)
+	}()
+
+	var wg sync.WaitGroup
+	wg.Add(5)
+
+	var counter atomic.Int64
+
+	for i := 0; i < 5; i++ {
+		go func() {
+			defer wg.Done()
+			for {
+				select {
+				case data := <-dq.ReadChan():
+					if !bytes.Equal(buf, data) {
+						t.Error("got corrupt message")
+						return
+					}
+					counter.Add(1)
+				case <-pushExit:
+					if dq.Depth() == 0 {
+						return
+					}
+				}
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	if counter.Load() != int64(datCount) {
+		t.Fatal("pushed message count does not match received message count")
+	}
+	if err := dq.Close(); err != nil {
+		t.Fatal(err)
+	}
+}