Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,12 @@ if(MLX_BUILD_CPU)
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
block(PROPAGATE openblas_SOURCE_DIR)
set(BUILD_SHARED_LIBS OFF) # link statically
set(BUILD_STATIC_LIBS ON)
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
endblock()
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
Expand Down
8 changes: 8 additions & 0 deletions benchmarks/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ function(build_benchmark SRCFILE)
set(target "${src_name}")
add_executable(${target} ${SRCFILE})
target_link_libraries(${target} PRIVATE mlx)
# On Windows, copy the mlx DLL to the executable directory for runtime loading
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope there is a better solution, but I'm fine with it as it is for now. I have created an issue to track this: #3031.

# Post-build step: place the mlx DLL next to the freshly built executable so
# it can be loaded at run time. Only applies when targeting Windows with mlx
# built as a shared library. VERBATIM keeps argument escaping consistent across
# generators and shells (e.g. build paths that contain spaces).
if(WIN32 AND BUILD_SHARED_LIBS)
  add_custom_command(
    TARGET ${target}
    POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:mlx>
            $<TARGET_FILE_DIR:${target}>
    VERBATIM)
endif()
endfunction(build_benchmark)

build_benchmark(single_ops.cpp)
Expand Down
1 change: 1 addition & 0 deletions docs/Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
PREDEFINED = MLX_API=

################################################################################
# Compound extraction control. #
Expand Down
8 changes: 8 additions & 0 deletions examples/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ function(build_example SRCFILE)
set(target "${src_name}")
add_executable(${target} ${SRCFILE})
target_link_libraries(${target} PRIVATE mlx)
# On Windows, copy the mlx DLL to the executable directory for runtime loading.
# This runs as a post-build step of the example target and only when mlx is
# built as a shared library. VERBATIM keeps argument escaping consistent across
# generators and shells (e.g. build paths that contain spaces).
if(WIN32 AND BUILD_SHARED_LIBS)
  add_custom_command(
    TARGET ${target}
    POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:mlx>
            $<TARGET_FILE_DIR:${target}>
    VERBATIM)
endif()
endfunction(build_example)

build_example(tutorial.cpp)
Expand Down
22 changes: 16 additions & 6 deletions mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,18 @@ target_sources(
# Define MLX_VERSION only in the version.cpp file.
# version.cpp is built as its own OBJECT library so that the MLX_VERSION
# definition (PRIVATE below) applies to that single source file.
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
# mlx_version needs access to mlx_export.h for MLX_API export macro
target_include_directories(mlx_version PRIVATE ${PROJECT_SOURCE_DIR})
# On Windows shared lib builds, mlx_version also needs MLX_EXPORT for proper DLL
# linkage
if(WIN32 AND BUILD_SHARED_LIBS)
target_compile_definitions(mlx_version PRIVATE MLX_EXPORT)
endif()
# Link the object library into mlx. The BUILD_INTERFACE wrapper limits the
# dependency to the build tree, so it is not recorded for installed/exported
# consumers of mlx.
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in
# 'std::pair<float, float>' when C++17 is enabled changed to match C++14 in
# GCC 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
endif()
Expand All @@ -39,11 +46,14 @@ if(MSVC)
# expression to only apply to C/CXX, not CUDA (NVCC doesn't understand /bigobj
# directly).
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/bigobj>)
endif()

if(WIN32)
# Export symbols by default to behave like macOS/linux.
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
# Windows DLLs have a 65535 symbol export limit. We use explicit exports via
# MLX_API macro (__declspec(dllexport)) on public API functions only. This
# avoids exporting internal template instantiations.
if(BUILD_SHARED_LIBS)
target_compile_definitions(mlx PRIVATE MLX_EXPORT)
else()
target_compile_definitions(mlx PUBLIC MLX_STATIC)
endif()
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
Expand Down
8 changes: 5 additions & 3 deletions mlx/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

#include <cstdlib>

#include "mlx/mlx_export.h"

namespace mlx::core::allocator {

// Simple wrapper around buffer pointers
// WARNING: Only Buffer objects constructed from and those that wrap
// raw pointers from mlx::allocator are supported.
class Buffer {
class MLX_API Buffer {
private:
void* ptr_;

Expand All @@ -28,7 +30,7 @@ class Buffer {
};
};

class Allocator {
class MLX_API Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
Expand All @@ -47,7 +49,7 @@ class Allocator {
virtual ~Allocator() = default;
};

Allocator& allocator();
MLX_API Allocator& allocator();

inline Buffer malloc(size_t size) {
return allocator().malloc(size);
Expand Down
7 changes: 4 additions & 3 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/mlx_export.h"
#include "mlx/small_vector.h"

namespace mlx::core {
Expand All @@ -22,7 +23,7 @@ using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;

class array {
class MLX_API array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
* object */

Expand Down Expand Up @@ -153,7 +154,7 @@ class array {
template <typename T>
T item() const;

struct ArrayIterator {
struct MLX_API ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
using value_type = const array;
Expand Down Expand Up @@ -464,7 +465,7 @@ class array {
template <typename It>
void init(const It src);

struct ArrayDesc {
struct MLX_API ArrayDesc {
Shape shape;
Strides strides;
size_t size;
Expand Down
12 changes: 11 additions & 1 deletion mlx/backend/cpu/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,23 @@ struct VisualStudioInfo {
arch = "x64";
#endif
// Get path of Visual Studio.
// Use -latest to get only the most recent installation when multiple
// versions are installed, avoiding path concatenation issues.
std::string vs_path = JitCompiler::exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
" -latest -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Trim any trailing whitespace/newlines from the path
vs_path.erase(
std::find_if(
vs_path.rbegin(),
vs_path.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
vs_path.end());
// Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/cuda/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#pragma once

#include "mlx/mlx_export.h"

namespace mlx::core::cu {

/* Check if the CUDA backend is available. */
bool is_available();
MLX_API bool is_available();

} // namespace mlx::core::cu
4 changes: 3 additions & 1 deletion mlx/backend/gpu/available.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

#pragma once

#include "mlx/mlx_export.h"

namespace mlx::core::gpu {

bool is_available();
MLX_API bool is_available();

} // namespace mlx::core::gpu
13 changes: 8 additions & 5 deletions mlx/backend/metal/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
#include <unordered_map>
#include <variant>

#include "mlx/mlx_export.h"

namespace mlx::core::metal {

/* Check if the Metal backend is available. */
bool is_available();
MLX_API bool is_available();

/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();
MLX_API void start_capture(std::string path = "");
MLX_API void stop_capture();

/** Get information about the GPU and system settings. */
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();
MLX_API const
std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();

} // namespace mlx::core::metal
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <mutex>

#include "mlx/allocator.h"
#include "mlx/memory.h"

#ifdef __APPLE__
#include "mlx/backend/no_gpu/apple_memory.h"
Expand Down
11 changes: 6 additions & 5 deletions mlx/compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
#pragma once

#include "mlx/array.h"
#include "mlx/mlx_export.h"

namespace mlx::core {

enum class CompileMode { disabled, no_simplify, no_fuse, enabled };

/** Compile takes a function and returns a compiled function. */
std::function<std::vector<array>(const std::vector<array>&)> compile(
MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
bool shapeless = false);

std::function<std::vector<array>(const std::vector<array>&)> compile(
MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
std::vector<array> (*fun)(const std::vector<array>&),
bool shapeless = false);

Expand All @@ -32,13 +33,13 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
MLX_API void disable_compile();

/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
MLX_API void enable_compile();

/** Set the compiler mode to the given value. */
void set_compile_mode(CompileMode mode);
MLX_API void set_compile_mode(CompileMode mode);
} // namespace mlx::core
9 changes: 5 additions & 4 deletions mlx/compile_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <unordered_map>

#include "mlx/array.h"
#include "mlx/mlx_export.h"

namespace mlx::core::detail {

Expand All @@ -14,24 +15,24 @@ using ArrayFnWithExtra =

// This is not part of the general C++ API as calling with a bad id is a bad
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::uintptr_t fun_id,
bool shapeless = false,
std::vector<uint64_t> constants = {});

ArrayFnWithExtra compile(
MLX_API ArrayFnWithExtra compile(
ArrayFnWithExtra fun,
std::uintptr_t fun_id,
bool shapeless,
std::vector<uint64_t> constants);

// Erase cached compile functions
void compile_erase(std::uintptr_t fun_id);
MLX_API void compile_erase(std::uintptr_t fun_id);

// Clear the compiler cache causing a recompilation of all compiled functions
// when called again.
void compile_clear_cache();
MLX_API void compile_clear_cache();

bool compile_available_for_device(const Device& device);

Expand Down
14 changes: 8 additions & 6 deletions mlx/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#pragma once

#include "mlx/mlx_export.h"

namespace mlx::core {

struct Device {
struct MLX_API Device {
enum class DeviceType {
cpu,
gpu,
Expand All @@ -19,13 +21,13 @@ struct Device {
int index;
};

const Device& default_device();
MLX_API const Device& default_device();

void set_default_device(const Device& d);
MLX_API void set_default_device(const Device& d);

bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);
MLX_API bool operator==(const Device& lhs, const Device& rhs);
MLX_API bool operator!=(const Device& lhs, const Device& rhs);

bool is_available(const Device& d);
MLX_API bool is_available(const Device& d);

} // namespace mlx::core
9 changes: 5 additions & 4 deletions mlx/distributed/distributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <memory>

#include "mlx/array.h"
#include "mlx/mlx_export.h"
#include "mlx/utils.h"

namespace mlx::core::distributed {
Expand All @@ -15,15 +16,15 @@ class GroupImpl;
};

/* Check if a communication backend is available */
bool is_available();
bool is_available(const std::string& bk);
MLX_API bool is_available();
MLX_API bool is_available(const std::string& bk);

/**
* A distributed::Group represents a group of independent mlx processes that
* can communicate. We must also be able to create sub-groups from a group in
* order to define more granular communication.
*/
struct Group {
struct MLX_API Group {
Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}

int rank() const;
Expand Down Expand Up @@ -55,6 +56,6 @@ struct Group {
* distributed subsystem. Otherwise simply return a singleton group which will
* render communication operations as no-op.
*/
Group init(bool strict = false, const std::string& bk = "any");
MLX_API Group init(bool strict = false, const std::string& bk = "any");

} // namespace mlx::core::distributed
Loading
Loading