From 2fa9d9f03f16044c60d296f095af2c1f9a44a839 Mon Sep 17 00:00:00 2001 From: Wouter Tichelaar <9594229+DifferentialityDevelopment@users.noreply.github.com> Date: Mon, 27 May 2024 23:03:41 +0200 Subject: [PATCH] feat: windows support. (#63) --- .github/workflows/main.yml | 36 ++++++++- .gitignore | 1 + Makefile | 22 ++++-- README.md | 35 ++++++++- src/app.cpp | 1 + src/apps/dllama-api/dllama-api.cpp | 8 +- src/common/pthread.h | 40 ++++++++++ src/funcs.cpp | 4 +- src/socket.cpp | 119 ++++++++++++++++++++++++----- src/tasks.cpp | 3 +- src/tokenizer.cpp | 5 +- src/transformer.cpp | 37 ++++----- src/utils.cpp | 84 +++++++++++++++++++- src/utils.hpp | 30 ++++++-- 14 files changed, 354 insertions(+), 71 deletions(-) create mode 100644 src/common/pthread.h diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 285cbc8..4b50411 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,8 +7,8 @@ on: branches: - main jobs: - build: - name: Build + build-linux: + name: Linux runs-on: ${{matrix.os}} strategy: matrix: @@ -22,9 +22,37 @@ jobs: uses: actions/checkout@v3 - name: Dependencies id: dependencies + run: sudo apt-get update && sudo apt-get install build-essential + - name: Build + id: build run: | - sudo apt-get update - sudo apt-get install build-essential + make dllama + make dllama-api + make funcs-test + make quants-test + make transformer-test + make llama2-tasks-test + make grok1-tasks-test + - name: funcs-test + run: ./funcs-test + - name: quants-test + run: ./quants-test + - name: transformer-test + run: ./transformer-test + - name: llama2-tasks-test + run: ./llama2-tasks-test + - name: grok1-tasks-test + run: ./grok1-tasks-test + + build-windows: + name: Windows + runs-on: windows-latest + steps: + - name: Checkout Repo + uses: actions/checkout@v3 + - name: Dependencies + id: dependencies + run: choco install make - name: Build id: build run: | diff --git a/.gitignore b/.gitignore index 77e86da..da9e5a2 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ run*.sh server /dllama /dllama-* +*.exe \ No newline at end of file diff --git a/Makefile b/Makefile index 96bdfb1..e938e8f 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,13 @@ CXX = g++ CXXFLAGS = -std=c++11 -Werror -O3 -march=native -mtune=native +# Conditional settings for Windows +ifeq ($(OS),Windows_NT) + LIBS = -lws2_32 # or -lpthreadGC2 if needed +else + LIBS = -lpthread +endif + utils: src/utils.cpp $(CXX) $(CXXFLAGS) -c src/utils.cpp -o utils.o quants: src/quants.cpp @@ -27,16 +34,17 @@ app: src/app.cpp $(CXX) $(CXXFLAGS) -c src/app.cpp -o app.o dllama: src/apps/dllama/dllama.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app - $(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o -lpthread + $(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS) dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app - $(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o -lpthread + $(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS) + funcs-test: src/funcs-test.cpp funcs utils quants - $(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o -lpthread + $(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS) quants-test: src/quants.cpp utils quants - $(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o -lpthread + $(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS) transformer-test: src/transformer-test.cpp funcs utils quants transformer socket - $(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o -lpthread + $(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS) llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer - $(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o -lpthread + $(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS) grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks tokenizer - $(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o -lpthread + $(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS) \ No newline at end of file diff --git a/README.md b/README.md index f5374ea..d4488cf 100644 --- a/README.md +++ b/README.md @@ -158,9 +158,11 @@ To add more worker nodes, just add more addresses to the `--workers` argument. [Share your results](https://github.com/b4rtaz/distributed-llama/discussions)! -## 💻 How to Run on MacOS or Linux +## 💻 How to Run on MacOS, Linux, or Windows -You need to have x86_64 AVX2 CPU or ARM CPU. Different devices may have different CPUs. The below instructions are for Debian-based distributions but you can easily adapt them to your distribution or macOS. +You need to have x86_64 AVX2 CPU or ARM CPU. Different devices may have different CPUs. The below instructions are for Debian-based distributions but you can easily adapt them to your distribution, macOS, or Windows. + +### MacOS and Linux 1. Install Git and G++: ```sh @@ -188,6 +190,35 @@ sudo nice -n -20 ./dllama inference --model ../dllama_llama-2-7b_q40.bin --token sudo nice -n -20 ./dllama chat --model ../dllama_llama-2-7b-chat_q40.bin --tokenizer ../dllama-llama2-tokenizer.t --weights-float-type q40 --buffer-float-type q80 --nthreads 4 --workers 192.168.0.1:9998 ``` +### Windows + +1. Install Git and Mingw (Chocolatey): + - https://chocolatey.org/install +```powershell +choco install mingw +``` +2. Clone this repository: +```sh +git clone https://github.com/b4rtaz/distributed-llama.git +``` +3. Compile Distributed Llama: +```sh +make dllama +``` +4. Transfer weights and the tokenizer file to the root node. +5. Run worker nodes on worker devices: +```sh +./dllama worker --port 9998 --nthreads 4 +``` +6. Run root node on the root device: +```sh +./dllama inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../dllama-llama2-tokenizer.t --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9998 +``` +7. To run the root node in the chat mode: +```sh +./dllama chat --model ../dllama_llama-2-7b-chat_q40.bin --tokenizer ../dllama-llama2-tokenizer.t --weights-float-type q40 --buffer-float-type q80 --nthreads 4 --workers 192.168.0.1:9998 +``` + [Share your results](https://github.com/b4rtaz/distributed-llama/discussions)! ## 💡 License diff --git a/src/app.cpp b/src/app.cpp index ff1b59d..ca89b37 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "app.hpp" FloatType parseFloatType(char* val) { diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp index c364e85..74daddc 100644 --- a/src/apps/dllama-api/dllama-api.cpp +++ b/src/apps/dllama-api/dllama-api.cpp @@ -5,10 +5,16 @@ #include #include #include +#include + +#ifdef _WIN32 +#include +#include +#else #include #include #include -#include +#endif #include "types.hpp" #include "../../utils.hpp" diff --git a/src/common/pthread.h b/src/common/pthread.h new file mode 100644 index 0000000..2e6a492 --- /dev/null +++ b/src/common/pthread.h @@ -0,0 +1,40 @@ +#ifndef PTHREAD_WRAPPER +#define PTHREAD_WRAPPER + +#ifdef _WIN32 +#include + +typedef HANDLE dl_thread; +typedef DWORD thread_ret_t; +typedef DWORD (WINAPI *thread_func_t)(void *); + +static int pthread_create(dl_thread * out, void * unused, thread_func_t func, void * arg) { + (void) unused; + dl_thread handle = CreateThread(NULL, 0, func, arg, 0, NULL); + if (handle == NULL) { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(dl_thread thread, void * unused) { + (void) unused; + DWORD ret = WaitForSingleObject(thread, INFINITE); + if (ret == WAIT_FAILED) { + return -1; + } + CloseHandle(thread); + return 0; +} +#else +#include + +typedef pthread_t dl_thread; +typedef void* thread_ret_t; +typedef void* (*thread_func_t)(void *); + +#endif + +#endif // PTHREAD_WRAPPER diff --git a/src/funcs.cpp b/src/funcs.cpp index ee8dd60..42f7881 100644 --- a/src/funcs.cpp +++ b/src/funcs.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include "common/pthread.h" #include "quants.hpp" #include "funcs.hpp" @@ -145,7 +145,7 @@ void rmsnorm(float* o, const float* x, const float ms, const float* weight, cons } struct MatmulThreadInfo { - pthread_t handler; + dl_thread handler; float* output; const void* input; const void* weights; diff --git a/src/socket.cpp b/src/socket.cpp index 4f944c5..1409ea5 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -1,21 +1,47 @@ #include #include #include -#include -#include #include #include #include #include -#include #include #include +#include +#include #include "socket.hpp" +#ifdef _WIN32 +#include +#include // For inet_addr and other functions +#include // For SSIZE_T +typedef SSIZE_T ssize_t; +#define close closesocket +#else +#include +#include +#include +#include +#endif + #define SOCKET_LAST_ERRCODE errno #define SOCKET_LAST_ERROR strerror(errno) +static inline bool isEagainError() { + #ifdef _WIN32 + return WSAGetLastError() == WSAEWOULDBLOCK; + #else + return SOCKET_LAST_ERRCODE == EAGAIN; + #endif +} + static inline void setNonBlocking(int socket, bool enabled) { +#ifdef _WIN32 + u_long mode = enabled ? 1 : 0; + if (ioctlsocket(socket, FIONBIO, &mode) != 0) { + throw std::runtime_error("Error setting socket to non-blocking"); + } +#else int flags = fcntl(socket, F_GETFL, 0); if (enabled) { flags |= O_NONBLOCK; @@ -24,6 +50,7 @@ static inline void setNonBlocking(int socket, bool enabled) { } if (fcntl(socket, F_SETFL, flags) < 0) throw std::runtime_error("Error setting socket to non-blocking"); +#endif } static inline void setNoDelay(int socket) { @@ -33,26 +60,45 @@ static inline void setNoDelay(int socket) { } static inline void setQuickAck(int socket) { +#ifndef _WIN32 #ifdef TCP_QUICKACK int value = 1; if (setsockopt(socket, IPPROTO_TCP, TCP_QUICKACK, (char*)&value, sizeof(int)) < 0) throw std::runtime_error("Error setting quick ack"); #endif +#endif +} + +static inline void setReuseAddr(int socket) { + int opt = 1; + #ifdef _WIN32 + int iresult = setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt)); + if (iresult == SOCKET_ERROR) { + closesocket(socket); + WSACleanup(); + throw std::runtime_error("setsockopt failed: " + std::to_string(WSAGetLastError())); + } + #else + if (setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + close(socket); + throw std::runtime_error("setsockopt failed: " + std::string(strerror(errno))); + } + #endif } static inline void writeSocket(int socket, const void* data, size_t size) { while (size > 0) { - int s = send(socket, (char*)data, size, 0); + int s = send(socket, (const char*)data, size, 0); if (s < 0) { - if (SOCKET_LAST_ERRCODE == EAGAIN) { + if (isEagainError()) { continue; } - throw WriteSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); + throw WriteSocketException(0, "Error writing to socket"); } else if (s == 0) { throw WriteSocketException(0, "Socket closed"); } size -= s; - data = (char*)data + s; + data = (const char*)data + s; } } @@ -60,9 +106,9 @@ static inline bool tryReadSocket(int socket, void* data, size_t size, unsigned l // maxAttempts = 0 means infinite attempts size_t s = size; while (s > 0) { - int r = recv(socket, data, s, 0); + int r = recv(socket, (char*)data, s, 0); if (r < 0) { - if (SOCKET_LAST_ERRCODE == EAGAIN) { + if (isEagainError()) { if (s == size && maxAttempts > 0) { maxAttempts--; if (maxAttempts == 0) { @@ -71,7 +117,7 @@ static inline bool tryReadSocket(int socket, void* data, size_t size, unsigned l } continue; } - throw ReadSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); + throw ReadSocketException(0, "Error reading from socket"); } else if (r == 0) { throw ReadSocketException(0, "Socket closed"); } @@ -82,7 +128,9 @@ static inline bool tryReadSocket(int socket, void* data, size_t size, unsigned l } static inline void readSocket(int socket, void* data, size_t size) { - assert(tryReadSocket(socket, data, size, 0)); + if (!tryReadSocket(socket, data, size, 0)) { + throw std::runtime_error("Error reading from socket"); + } } ReadSocketException::ReadSocketException(int code, const char* message) { @@ -169,9 +217,9 @@ void SocketPool::writeMany(unsigned int n, SocketIo* ios) { if (io->size > 0) { isWriting = true; int socket = sockets[io->socketIndex]; - ssize_t s = send(socket, io->data, io->size, 0); + ssize_t s = send(socket, (const char*)io->data, io->size, 0); if (s < 0) { - if (SOCKET_LAST_ERRCODE == EAGAIN) { + if (isEagainError()) { continue; } throw WriteSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); @@ -201,7 +249,7 @@ void SocketPool::readMany(unsigned int n, SocketIo* ios) { int socket = sockets[io->socketIndex]; ssize_t r = recv(socket, (char*)io->data, io->size, 0); if (r < 0) { - if (SOCKET_LAST_ERRCODE == EAGAIN) { + if (isEagainError()) { continue; } throw ReadSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); @@ -294,30 +342,59 @@ SocketServer::SocketServer(int port) { const char* host = "0.0.0.0"; struct sockaddr_in serverAddr; - socket = ::socket(AF_INET, SOCK_STREAM, 0); + #ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError())); + } + #endif + + socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (socket < 0) throw std::runtime_error("Cannot create socket"); + setReuseAddr(socket); memset(&serverAddr, 0, sizeof(serverAddr)); serverAddr.sin_family = AF_INET; serverAddr.sin_port = htons(port); serverAddr.sin_addr.s_addr = inet_addr(host); - int bindResult = bind(socket, (struct sockaddr*)&serverAddr, sizeof(serverAddr)); + int bindResult; + #ifdef _WIN32 + bindResult = bind(socket, (SOCKADDR*)&serverAddr, sizeof(serverAddr)); + if (bindResult == SOCKET_ERROR) { + int error = WSAGetLastError(); + closesocket(socket); + WSACleanup(); + throw std::runtime_error("Cannot bind port: " + std::to_string(error)); + } + #else + bindResult = bind(socket, (struct sockaddr*)&serverAddr, sizeof(serverAddr)); if (bindResult < 0) { - printf("Cannot bind %s:%d\n", host, port); - throw std::runtime_error("Cannot bind port"); + close(socket); + throw std::runtime_error("Cannot bind port: " + std::string(strerror(errno))); } + #endif - int listenResult = listen(socket, 1); + int listenResult = listen(socket, SOMAXCONN); if (listenResult != 0) { - printf("Cannot listen %s:%d\n", host, port); - throw std::runtime_error("Cannot listen port"); + #ifdef _WIN32 + closesocket(socket); + WSACleanup(); + throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError())); + #else + close(socket); + throw std::runtime_error("Cannot listen on port: " + std::string(strerror(errno))); + #endif } + printf("Listening on %s:%d...\n", host, port); } SocketServer::~SocketServer() { shutdown(socket, 2); + #ifdef _WIN32 + WSACleanup(); + #endif close(socket); } diff --git a/src/tasks.cpp b/src/tasks.cpp index 4bbc911..5b4b187 100644 --- a/src/tasks.cpp +++ b/src/tasks.cpp @@ -1,7 +1,8 @@ -#include "tasks.hpp" #include #include #include +#include +#include "tasks.hpp" TransformerArch::TransformerArch() { inference.nTasks = 0; diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index dc597a1..8cdaae9 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -356,10 +355,10 @@ int Sampler::sample(float* logits) { return next; } -void Sampler::setTemp(float temp){ +void Sampler::setTemp(float temp) { this->temperature = temp; } -void Sampler::setSeed(unsigned long long seed){ +void Sampler::setSeed(unsigned long long seed) { this->rngState = seed; } \ No newline at end of file diff --git a/src/transformer.cpp b/src/transformer.cpp index 7233dea..2f95c53 100644 --- a/src/transformer.cpp +++ b/src/transformer.cpp @@ -2,10 +2,7 @@ #include #include #include -#include #include -#include -#include #include "funcs.hpp" #include "utils.hpp" #include "socket.hpp" @@ -285,10 +282,14 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i printf("💡 ropeTheta: %.1f\n", spec.ropeTheta); fseek(fd, 0, SEEK_END); - size_t fileSize = ftell(fd); + long fileSize = ftell(fd); + if (fileSize == -1L) { + fclose(fd); + throw std::runtime_error("Error determining model file size"); + } fclose(fd); - spec.fileSize = fileSize; + spec.fileSize = static_cast(fileSize); return spec; } @@ -607,24 +608,18 @@ static size_t readSlicedMatmulWeights(MatmulSlice* slice, char* weights0, Socket } Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spec, SocketPool* socketPool) { - int fd = open(path, O_RDONLY); - if (fd == -1) { - printf("Cannot open file %s\n", path); - exit(EXIT_FAILURE); - } - char* data = (char*)mmap(NULL, spec->fileSize, PROT_READ, MAP_PRIVATE, fd, 0); - if (data == MAP_FAILED) { - printf("Mmap failed!\n"); - exit(EXIT_FAILURE); - } - char* weights = data + spec->headerSize; - Transformer transformer = Transformer::loadRoot(weights, spec, socketPool); + MmapFile file; + openMmapFile(&file, path, spec->fileSize); + + char* weights = ((char*)file.data) + spec->headerSize; + Transformer transformer = Transformer::loadRoot((char*)weights, spec, socketPool); + #if ALLOC_WEIGHTS - munmap(data, spec->fileSize); - close(fd); + closeMmapFile(&file); #else - // TODO: handler should be released in deconstructor + // TODO: handler should be released in destructor #endif + return transformer; } @@ -682,7 +677,7 @@ Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* long missedBytes = (long)(w - data) - spec->fileSize + spec->headerSize; if (missedBytes != 0) { - printf("Missed %ld bytes\n", missedBytes); + printf("The model file is missing %ld bytes\n", missedBytes); exit(EXIT_FAILURE); } diff --git a/src/utils.cpp b/src/utils.cpp index 40a8f1a..31518f2 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,13 +1,30 @@ #include #include +#include +#include +#include #include -#include #include "utils.hpp" #define BUFFER_ALIGNMENT 16 -char* newBuffer(size_t size) { - char* buffer; +#ifdef _WIN32 +#include +#else +#include +#include +#include +#endif + +void* newBuffer(size_t size) { + void* buffer; +#ifdef _WIN32 + buffer = _aligned_malloc(size, BUFFER_ALIGNMENT); + if (buffer == NULL) { + fprintf(stderr, "error: _aligned_malloc failed\n"); + exit(EXIT_FAILURE); + } +#else if (posix_memalign((void**)&buffer, BUFFER_ALIGNMENT, size) != 0) { fprintf(stderr, "error: posix_memalign failed\n"); exit(EXIT_FAILURE); @@ -15,9 +32,18 @@ char* newBuffer(size_t size) { if (mlock(buffer, size) != 0) { fprintf(stderr, "🚧 Cannot allocate %zu bytes directly in RAM\n", size); } +#endif return buffer; } +void freeBuffer(void* buffer) { +#ifdef _WIN32 + _aligned_free(buffer); +#else + free(buffer); +#endif +} + unsigned long timeMs() { struct timeval te; gettimeofday(&te, NULL); @@ -37,6 +63,56 @@ float randomF32(unsigned long long *state) { return (randomU32(state) >> 8) / 16777216.0f; } +void openMmapFile(MmapFile* file, const char* path, size_t size) { + file->size = size; +#ifdef _WIN32 + file->hFile = CreateFileA(path, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); + if (file->hFile == INVALID_HANDLE_VALUE) { + printf("Cannot open file %s\n", path); + exit(EXIT_FAILURE); + } + + file->hMapping = CreateFileMappingA(file->hFile, NULL, PAGE_READONLY, 0, 0, NULL); + if (file->hMapping == NULL) { + printf("CreateFileMappingA failed, error: %lu\n", GetLastError()); + CloseHandle(file->hFile); + exit(EXIT_FAILURE); + } + + file->data = (char*)MapViewOfFile(file->hMapping, FILE_MAP_READ, 0, 0, 0); + if (file->data == NULL) { + printf("MapViewOfFile failed!\n"); + CloseHandle(file->hMapping); + CloseHandle(file->hFile); + exit(EXIT_FAILURE); + } +#else + file->fd = open(path, O_RDONLY); + if (file->fd == -1) { + printf("Cannot open file %s\n", path); + exit(EXIT_FAILURE); + } + + file->data = mmap(NULL, size, PROT_READ, MAP_PRIVATE, file->fd, 0); + if (file->data == MAP_FAILED) { + printf("Mmap failed!\n"); + close(file->fd); + exit(EXIT_FAILURE); + } +#endif +} + +void closeMmapFile(MmapFile* file) { +#ifdef _WIN32 + UnmapViewOfFile(file->data); + CloseHandle(file->hMapping); + CloseHandle(file->hFile); +#else + munmap(file->data, file->size); + close(file->fd); +#endif +} + TaskLoop::TaskLoop(unsigned int nThreads, unsigned int nTasks, unsigned int nTypes, TaskLoopTask* tasks, void* userData) { this->nThreads = nThreads; this->nTasks = nTasks; @@ -69,7 +145,7 @@ void TaskLoop::run() { } for (i = 1; i < nThreads; i++) { - int result = pthread_create(&threads[i].handler, NULL, threadHandler, (void*)&threads[i]); + int result = pthread_create(&threads[i].handler, NULL, (thread_func_t)threadHandler, (void*)&threads[i]); if (result != 0) { printf("Cannot created thread\n"); exit(EXIT_FAILURE); diff --git a/src/utils.hpp b/src/utils.hpp index fd6cfae..e9af964 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -2,16 +2,36 @@ #define UTILS_HPP #include -#include +#include "common/pthread.h" -#define NEW_BUFFER(size) newBuffer(size) -#define FREE_BUFFER(buffer) free(buffer) +#ifdef _WIN32 +#include +#endif + +#define NEW_BUFFER(size) (char*)newBuffer(size) +#define FREE_BUFFER(buffer) freeBuffer(buffer) + +void* newBuffer(size_t size); +void freeBuffer(void* buffer); -char* newBuffer(size_t size); unsigned long timeMs(); unsigned int randomU32(unsigned long long *state); float randomF32(unsigned long long *state); +struct MmapFile { + void* data; + size_t size; +#ifdef _WIN32 + HANDLE hFile; + HANDLE hMapping; +#else + int fd; +#endif +}; + +void openMmapFile(MmapFile* file, const char* path, size_t size); +void closeMmapFile(MmapFile* file); + typedef void (TaskLoopHandler)(unsigned int nThreads, unsigned int threadIndex, void* userData); typedef struct { TaskLoopHandler* handler; @@ -23,7 +43,7 @@ class TaskLoop; struct TaskLoopThread { unsigned int threadIndex; unsigned int nTasks; - pthread_t handler; + dl_thread handler; TaskLoop* loop; };