Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Eigen blas usage more robust trough -Xclang Definition #4

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,31 @@ message("Found LLVM at: " ${Enzyme_LLVM_BINARY_DIR})
set(CMAKE_C_COMPILER "${Enzyme_LLVM_BINARY_DIR}/bin/clang")
set(CMAKE_CXX_COMPILER "${Enzyme_LLVM_BINARY_DIR}/bin/clang++")

project(EnzymeExample)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --save-temps -Xclang -D -Xclang EIGEN_USE_BLAS -Xclang -no-opaque-pointers")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --save-temps -Xclang -D -Xclang EIGEN_USE_BLAS -Xclang -no-opaque-pointers")

project(EnzymeExample CXX)

set(BLA_VENDER OpenBLAS)
find_package(BLAS REQUIRED)
if(BLAS_FOUND)
message("OpenBLAS found.")
#include_directories(/opt/OpenBLAS/include/)
#target_link_libraries(main.exe ${BLAS_LIBRARIES})
endif(BLAS_FOUND)

message("found dir ${Enzyme_DIR}")
message("found ${Enzyme_FOUND}")
get_property(importTargetsAfter DIRECTORY "${CMAKE_SOURCE_DIR}" PROPERTY IMPORTED_TARGETS)
message("imported targets ${importTargetsAfter}")


find_package (Eigen3 3.3 REQUIRED NO_MODULE)


add_executable(example
multisource.c
multisource.cpp
myblas.c
myblas.h
)
target_link_libraries(example PUBLIC LLDEnzymeFlags)
target_link_libraries(example PUBLIC LLDEnzymeFlags Eigen3::Eigen ${BLAS_LIBRARIES})
48 changes: 0 additions & 48 deletions multisource.c

This file was deleted.

70 changes: 70 additions & 0 deletions multisource.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "myblas.h"
#include <assert.h>
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// #define EIGEN_USE_BLAS

#include<Eigen/Core>

double dotabs(struct complex* alpha, struct complex* beta, int n) {
struct complex prod = myblas_cdot(alpha, beta, n);
return myblas_cabs(prod);
}

void __enzyme_autodiff(void*, ...);
int enzyme_const, enzyme_dup, enzyme_out;

using Eigen::MatrixXd;
using Eigen::VectorXd;

void foo(MatrixXd *m, VectorXd *v) { *v = *m * *v; }

int main(int argc, char *argv[]) {
// int size = 50;
//// VectorXf is a vector of floats, with dynamic size.
// Eigen::VectorXf u(size), v(size), w(size);
// u = v + w;

MatrixXd m = MatrixXd::Random(30, 30);
MatrixXd dm = MatrixXd::Random(30, 30);
m = (m + MatrixXd::Constant(30, 30, 1.2)) * 50;
std::cout << "m =" << std::endl << m << std::endl;
VectorXd v = VectorXd::Random(30);
VectorXd dv = VectorXd::Random(30);
// v << 1, 2, 3;
// std::cout << "m * v =" << std::endl << m * v << std::endl;

__enzyme_autodiff((void *)foo, &m, &dm, &v, &dv);
std::cout << "dm, dv: =" << std::endl << dm << std::endl << dv << std::endl;

// int n = 3;
// if (argc > 1) {
// n = atoi(argv[1]);
// }

// struct complex *A = (struct complex*)malloc(sizeof(struct complex) * n);
// assert(A != 0);
// for(int i=0; i<n; i++)
// A[i] = (struct complex){(i+1), (i+2)};

// struct complex *grad_A = (struct complex*)malloc(sizeof(struct complex) *
// n); assert(grad_A != 0); for(int i=0; i<n; i++)
// grad_A[i] = (struct complex){0,0};

// struct complex *B = (struct complex*)malloc(sizeof(struct complex) * n);
// assert(B != 0);
// for(int i=0; i<n; i++)
// B[i] = (struct complex){-3-i, 2*i};

// struct complex *grad_B = (struct complex*)malloc(sizeof(struct complex) *
// n); assert(grad_B != 0); for(int i=0; i<n; i++)
// grad_B[i] = (struct complex){0,0};

//__enzyme_autodiff((void*)dotabs, A, grad_A, B, grad_B, n);
// printf("Gradient dotabs(A)[0] = %f\n", grad_A[0].r);

return 0;
}