Skip to content

Commit

Permalink
refactor: implement delta add/subtract by actually adding and subtrac…
Browse files Browse the repository at this point in the history
…ting weights (#4486)

Previously, subtraction was done by copying weights from the newer model, and addition was done by overwriting the older weights. The results are identical, but the new implementation is more intuitive. It also makes model merging with different base models easier to implement in the future.
  • Loading branch information
byronxu99 authored Feb 2, 2023
1 parent fc5296b commit a2dc620
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions vowpalwabbit/core/src/reductions/gd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,17 @@ void merge_weights_with_save_resume(size_t length,
}

template <typename WeightsT>
void copy_weights(WeightsT& dest, const WeightsT& source, size_t length)
void add_weights(WeightsT& dest, const WeightsT& lhs, const WeightsT& rhs, size_t length)
{
const size_t full_weights_size = length << dest.stride_shift();
for (size_t i = 0; i < full_weights_size; i++) { dest[i] = source[i]; }
for (size_t i = 0; i < full_weights_size; i++) { dest[i] = lhs[i] + rhs[i]; }
}

// Writes the element-wise difference of two weight containers into dest,
// i.e. dest[k] = lhs[k] - rhs[k] for every raw index k in the covered range.
//
// WeightsT must provide stride_shift() and operator[] (mutable on dest,
// readable on lhs/rhs).
//
// dest   - container receiving the differences
// lhs    - minuend weights
// rhs    - subtrahend weights
// length - number of entries before stride scaling; the loop actually visits
//          length << dest.stride_shift() raw indices
template <typename WeightsT>
void subtract_weights(WeightsT& dest, const WeightsT& lhs, const WeightsT& rhs, size_t length)
{
  // The stride is taken from dest only; lhs and rhs are indexed over the same raw range.
  const size_t raw_count = length << dest.stride_shift();
  for (size_t idx = 0; idx != raw_count; ++idx)
  {
    dest[idx] = lhs[idx] - rhs[idx];
  }
}

void sync_weights(VW::workspace& all)
Expand Down Expand Up @@ -274,13 +281,15 @@ void merge(const std::vector<float>& per_model_weighting, const std::vector<cons
}
}

void add(const VW::workspace& /* ws1 */, const VW::reductions::gd& data1, const VW::workspace& ws2,
VW::reductions::gd& data2, VW::workspace& ws_out, VW::reductions::gd& data_out)
void add(const VW::workspace& ws1, const VW::reductions::gd& data1, const VW::workspace& ws2, VW::reductions::gd& data2,
VW::workspace& ws_out, VW::reductions::gd& data_out)
{
const size_t length = static_cast<size_t>(1) << ws_out.num_bits;
// When adding, output the weights from the model delta (2nd argument to addition)
if (ws_out.weights.sparse) { copy_weights(ws_out.weights.sparse_weights, ws2.weights.sparse_weights, length); }
else { copy_weights(ws_out.weights.dense_weights, ws2.weights.dense_weights, length); }
if (ws_out.weights.sparse)
{
add_weights(ws_out.weights.sparse_weights, ws1.weights.sparse_weights, ws2.weights.sparse_weights, length);
}
else { add_weights(ws_out.weights.dense_weights, ws1.weights.dense_weights, ws2.weights.dense_weights, length); }

for (size_t i = 0; i < data_out.per_model_states.size(); i++)
{
Expand All @@ -293,13 +302,15 @@ void add(const VW::workspace& /* ws1 */, const VW::reductions::gd& data1, const
}
}

void subtract(const VW::workspace& ws1, const VW::reductions::gd& data1, const VW::workspace& /* ws2 */,
void subtract(const VW::workspace& ws1, const VW::reductions::gd& data1, const VW::workspace& ws2,
VW::reductions::gd& data2, VW::workspace& ws_out, VW::reductions::gd& data_out)
{
const size_t length = static_cast<size_t>(1) << ws_out.num_bits;
// When subtracting, output the weights from the newer model (1st argument to subtraction)
if (ws_out.weights.sparse) { copy_weights(ws_out.weights.sparse_weights, ws1.weights.sparse_weights, length); }
else { copy_weights(ws_out.weights.dense_weights, ws1.weights.dense_weights, length); }
if (ws_out.weights.sparse)
{
subtract_weights(ws_out.weights.sparse_weights, ws1.weights.sparse_weights, ws2.weights.sparse_weights, length);
}
else { subtract_weights(ws_out.weights.dense_weights, ws1.weights.dense_weights, ws2.weights.dense_weights, length); }

for (size_t i = 0; i < data_out.per_model_states.size(); i++)
{
Expand Down

0 comments on commit a2dc620

Please sign in to comment.