Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Experimental) Update message properties via Hostapi #868

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/problem-matchers.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@
"message": 4
}
]
},
{
"owner": "swig",
"pattern": [
{
"regexp": "^(.*):(\\d+):\\s+(Warning|Error):\\s+(.*)$",
"file": 1,
"line": 2,
"severity": 3,
"message": 4
}
]
}
]
}
11 changes: 10 additions & 1 deletion examples/circles_spatial3D/src/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,18 @@ FLAMEGPU_STEP_FUNCTION(Validation) {
prevTotalDrift = totalDrift;
// printf("Avg Drift: %g\n", totalDrift / FLAMEGPU->agent("Circle").count());
printf("%.2f%% Drift correct\n", 100 * driftDropped / static_cast<float>(driftDropped + driftIncreased));

// Change radius
static float radius_swap = 4.0f;
if (!((FLAMEGPU->getStepCounter()+1) % 200)) {
auto msg = FLAMEGPU->message<flamegpu::MessageSpatial3D>("location");
const float t = msg.getRadius();
msg.setRadius(radius_swap);
radius_swap = t;
}
}
int main(int argc, const char ** argv) {
flamegpu::ModelDescription model("Circles_BruteForce_example");
flamegpu::ModelDescription model("Circles_Spatial3D_example");

const unsigned int AGENT_COUNT = 16384;
const float ENV_MAX = static_cast<float>(floor(cbrt(AGENT_COUNT)));
Expand Down
20 changes: 18 additions & 2 deletions include/flamegpu/gpu/CUDAMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// include sub classes
#include "flamegpu/gpu/CUDAMessageList.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceHost.h"
#include "flamegpu/runtime/messaging/MessageSpatial3D/MessageSpatial3DHost.h"

// forward declare classes from other modules

Expand Down Expand Up @@ -45,6 +46,13 @@ class CUDAMessage {
* Return an immutable reference to the message description represented by the CUDAMessage instance
*/
const MessageBruteForce::Data& getMessageDescription() const;
/**
* Return a HostAPI object for the message
*
* This can be used to update the bounds or similar.
*/
template<typename Msg>
typename Msg::HostAPI getHostAPI();
/**
* @return The currently allocated length of the message array (in the number of messages)
*/
Expand Down Expand Up @@ -94,8 +102,8 @@ class CUDAMessage {
*/
void mapWriteRuntimeVariables(const AgentFunctionData& func, const CUDAAgent& cuda_agent, const unsigned int &writeLen, cudaStream_t stream) const;
void *getReadPtr(const std::string &var_name);
const CUDAMessageMap &getReadList() { return message_list->getReadList(); }
const CUDAMessageMap &getWriteList() { return message_list->getWriteList(); }
const CUDAMessageList::MessageMap &getReadList() { return message_list->getReadList(); }
const CUDAMessageList::MessageMap&getWriteList() { return message_list->getWriteList(); }
/**
* Swaps the two internal maps within message_list
* @param isOptional If optional newMessageCount will be reduced based on scan_flag[streamId]
Expand Down Expand Up @@ -176,6 +184,14 @@ class CUDAMessage {
*/
const CUDASimulation& cudaSimulation;
};
template<typename Msg>
typename Msg::HostAPI CUDAMessage::getHostAPI() {
auto t = dynamic_cast<typename Msg::CUDAModelHandler*>(specialisation_handler.get());
if (t) {
return Msg::HostAPI(*this, *t);
}
THROW exception::InvalidMessageType("Message %s is not of type %s, in HostAPI::message()\n", message_description.name.c_str(), std::type_index(typeid(typename Msg)).name());
}

} // namespace flamegpu

Expand Down
36 changes: 18 additions & 18 deletions include/flamegpu/gpu/CUDAMessageList.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@ namespace flamegpu {
class CUDAScatter;
class CUDAMessage;

/**
* Map used to map a variable name to buffer
*/
typedef std::map <std::string, void*> CUDAMessageMap;
/**
* Key Value pair of CUDAMessageMap
*/
typedef std::pair <std::string, void*> CUDAMessageMapPair;

/**
* This is the internal device memory handler for CUDAMessage
* @todo This could just be merged with CUDAMessage
*/
class CUDAMessageList {
public:
/**
* Initially allocates message lists based on cuda_message.getMaximumListSize()
*/
/**
* Map used to map a variable name to buffer
*/
typedef std::map<std::string, void*> MessageMap;
/**
* Key Value pair of CUDAMessageMap
*/
typedef std::pair<std::string, void*> MessageMapPair;
/**
* Initially allocates message lists based on cuda_message.getMaximumListSize()
*/
explicit CUDAMessageList(CUDAMessage& cuda_message, CUDAScatter &scatter, cudaStream_t stream, unsigned int streamId);
/**
* Frees all message list memory
Expand Down Expand Up @@ -93,40 +93,40 @@ class CUDAMessageList {
/**
* @return Returns the map<variable_name, device_ptr> for reading message data
*/
const CUDAMessageMap &getReadList() { return d_list; }
const MessageMap &getReadList() { return d_list; }
/**
* @return Returns the map<variable_name, device_ptr> for writing message data (aka swap buffers)
*/
const CUDAMessageMap &getWriteList() { return d_swap_list; }
const MessageMap &getWriteList() { return d_swap_list; }

protected:
/**
* Allocates device memory for the provided message list
* @param memory_map Message list to perform operation on
*/
void allocateDeviceMessageList(CUDAMessageMap &memory_map);
void allocateDeviceMessageList(MessageMap &memory_map);
/**
* Frees device memory for the provided message list
* @param memory_map Message list to perform operation on
*/
void releaseDeviceMessageList(CUDAMessageMap &memory_map);
void releaseDeviceMessageList(MessageMap &memory_map);
/**
* Zeros device memory for the provided message list
* @param memory_map Message list to perform operation on
* @param stream The CUDAStream to use for CUDA operations
* @param skip_offset Number of items at the start of the list to not zero
*/
void zeroDeviceMessageList_async(CUDAMessageMap &memory_map, cudaStream_t stream, unsigned int skip_offset = 0);
void zeroDeviceMessageList_async(MessageMap &memory_map, cudaStream_t stream, unsigned int skip_offset = 0);

private:
/**
* Message storage for reading
*/
CUDAMessageMap d_list;
MessageMap d_list;
/**
* Message storage for writing
*/
CUDAMessageMap d_swap_list;
MessageMap d_swap_list;
/**
* Parent which this provides storage for
*/
Expand Down
20 changes: 20 additions & 0 deletions include/flamegpu/runtime/HostAPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <vector>
#include <memory>

#include "flamegpu/gpu/CUDAMessage.h"
#include "flamegpu/gpu/CUDAMessageList.h"
#include "flamegpu/gpu/detail/CUDAErrorChecking.cuh"
#include "flamegpu/runtime/utility/HostRandom.cuh"
#include "flamegpu/runtime/utility/HostEnvironment.cuh"
Expand Down Expand Up @@ -53,6 +55,7 @@ class HostAPI {
CUDAScatter &scatter,
const AgentOffsetMap &agentOffsets,
AgentDataMap &agentData,
std::unordered_map<std::string, std::unique_ptr<CUDAMessage>> &_messageMap,
const std::shared_ptr<EnvironmentManager> &env,
CUDAMacroEnvironment &macro_env,
const unsigned int &streamId,
Expand All @@ -65,6 +68,11 @@ class HostAPI {
* Returns methods that work on all agents of a certain type currently in a given state
*/
HostAgentAPI agent(const std::string &agent_name, const std::string &stateName = ModelData::DEFAULT_STATE);
/**
* Returns methods that work on all agents of a certain type currently in a given state
*/
template<typename Msg>
typename Msg::HostAPI message(const std::string& message_name);
/**
* Host API access to seeded random number generation
*/
Expand Down Expand Up @@ -97,6 +105,10 @@ class HostAPI {
* when new agents are copied to device.
*/
AgentDataMap &agentData;
/**
* References to the model's messages
*/
std::unordered_map<std::string, std::unique_ptr<CUDAMessage>> &messageMap;
/**
* Cuda scatter singleton
*/
Expand All @@ -121,6 +133,14 @@ void HostAPI::resizeOutputSpace(const unsigned int &items) {
d_output_space_size = sizeof(T) * items;
}
}
template<typename Msg>
typename Msg::HostAPI HostAPI::message(const std::string& message_name) {
const auto &it = messageMap.find(message_name);
if (it == messageMap.end()) {
THROW exception::InvalidMessageName("Message with name '%s' was not found within the model, in HostAPI::message()\n", message_name.c_str());
}
return it->second->getHostAPI<Msg>();
}

} // namespace flamegpu

Expand Down
1 change: 1 addition & 0 deletions include/flamegpu/runtime/messaging/MessageSpatial3D.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MessageSpatial3D {
struct Data; // Forward declare inner classes
class Description; // Forward declare inner classes
class CUDAModelHandler;
class HostAPI;
// Device
class In;
class Out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
#include <memory>
#include <string>

#include "flamegpu/gpu/CUDAMessage.h"
#include "flamegpu/util/nvtx.h"
#include "flamegpu/runtime/messaging/MessageSpatial3D.h"
#include "flamegpu/runtime/messaging/MessageSpatial2D/MessageSpatial2DHost.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceHost.h"

class CUDAMessage;

namespace flamegpu {

/**
* CUDA host side handler of spatial messages
* Allocates memory for and constructs PBM
*/
class MessageSpatial3D::CUDAModelHandler : public MessageSpecialisationHandler {
friend class MessageSpatial3D::HostAPI;

public:
/**
* Constructor
Expand Down Expand Up @@ -60,6 +63,10 @@ class MessageSpatial3D::CUDAModelHandler : public MessageSpecialisationHandler {
* Returns a pointer to the metadata struct, this is required for reading the message data
*/
const void *getMetaDataDevicePtr() const override { return d_data; }
/**
* On next PBM rebuild the pbm may be reallocated and metadata updated on device
*/
void setMetadataChangedFlag() { metadata_changed_flag = true; }

private:
/**
Expand All @@ -79,6 +86,10 @@ class MessageSpatial3D::CUDAModelHandler : public MessageSpecialisationHandler {
* Number of bins, arrays are +1 this length
*/
unsigned int binCount = 0;
/**
* Number of bins in hd_data.PBM
*/
unsigned int allocated_binCount = 0;
/**
* Size of currently allocated temp storage memory for cub
*/
Expand Down Expand Up @@ -111,6 +122,10 @@ class MessageSpatial3D::CUDAModelHandler : public MessageSpecialisationHandler {
* Owning CUDAMessage, provides access to message storage etc
*/
CUDAMessage &sim_message;
/**
* If true, PBM rebuild will realloc a different sized PBM first
*/
bool metadata_changed_flag;
};

/**
Expand Down Expand Up @@ -201,6 +216,48 @@ class MessageSpatial3D::Description : public MessageBruteForce::Description {
float getMaxZ() const;
};

class MessageSpatial3D::HostAPI {
CUDAMessage &cudaMessage;
CUDAModelHandler &messageHandler;

public:
HostAPI(CUDAMessage &_cudaMessage, CUDAModelHandler &_messageHandler)
: cudaMessage(_cudaMessage)
, messageHandler(_messageHandler) { }
void setRadius(const float& r);
void setMinX(const float& x);
void setMinY(const float& y);
void setMinZ(const float& z);
void setMin(const float& x, const float& y, const float& z);
void setMaxX(const float& x);
void setMaxY(const float& y);
void setMaxZ(const float& z);
void setMax(const float& x, const float& y, const float& z);
void clearMessages();

float getRadius() const {
return messageHandler.hd_data.radius;
}
float getMinX() const {
return messageHandler.hd_data.min[0];
}
float getMinY() const {
return messageHandler.hd_data.min[1];
}
float getMinZ() const {
return messageHandler.hd_data.min[2];
}
float getMaxX() const {
return messageHandler.hd_data.max[0];
}
float getMaxY() const {
return messageHandler.hd_data.max[1];
}
float getMaxZ() const {
return messageHandler.hd_data.max[2];
}
};

} // namespace flamegpu

#endif // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL3D_MESSAGESPATIAL3DHOST_H_
Loading