Skip to content

Commit 9fb792d

Browse files
authored
Implement topological sort. (k2-fsa#18)
1 parent 4daae23 commit 9fb792d

File tree

8 files changed

+462
-24
lines changed

8 files changed

+462
-24
lines changed

cmake/googletest.cmake

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ function(download_googltest)
3333
message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}")
3434
message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}")
3535

36+
if(APPLE)
37+
set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS
38+
endif()
39+
#[==[
40+
-- Generating done
41+
Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake
42+
--help-policy CMP0042" for policy details. Use the cmake_policy command to
43+
set the policy and suppress this warning.
44+
45+
MACOSX_RPATH is not specified for the following targets:
46+
47+
gmock
48+
gmock_main
49+
gtest
50+
gtest_main
51+
52+
This warning is for project developers. Use -Wno-dev to suppress it.
53+
]==]
54+
3655
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
3756

3857
target_include_directories(gtest

k2/csrc/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ function(k2_add_fsa_test name)
2525
gtest
2626
gtest_main
2727
)
28-
2928
add_test(NAME "Test.${name}"
3029
COMMAND
3130
$<TARGET_FILE:${name}>

k2/csrc/fsa_algo.cc

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <utility>
1414

1515
#include "glog/logging.h"
16+
#include "k2/csrc/properties.h"
1617

1718
namespace {
1819

@@ -57,7 +58,7 @@ void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map) {
5758
auto state = current_state.state; // get a copy since we will destroy it
5859
stack.pop();
5960
if (!stack.empty()) {
60-
// if it has a parent, set the parent's coaccessible flag
61+
// if it has a parent, set the parent's co-accessible flag
6162
if (coaccessible[state]) {
6263
auto &parent = stack.top();
6364
coaccessible[parent.state] = true;
@@ -98,7 +99,7 @@ void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map) {
9899

99100
void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
100101
CHECK_NOTNULL(b);
101-
if (arc_map) arc_map->clear();
102+
if (arc_map != nullptr) arc_map->clear();
102103

103104
std::vector<int32_t> state_b_to_a;
104105
ConnectCore(a, &state_b_to_a);
@@ -108,7 +109,7 @@ void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
108109
b->arcs.clear();
109110
b->arcs.reserve(a.arcs.size());
110111

111-
if (arc_map) {
112+
if (arc_map != nullptr) {
112113
arc_map->clear();
113114
arc_map->reserve(a.arcs.size());
114115
}
@@ -144,9 +145,7 @@ void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
144145
arc.src_state = i;
145146
arc.dest_state = state_b;
146147
b->arcs.push_back(arc);
147-
if (arc_map) {
148-
arc_map->push_back(arc_begin);
149-
}
148+
if (arc_map != nullptr) arc_map->push_back(arc_begin);
150149
}
151150
}
152151
}
@@ -194,4 +193,112 @@ void ArcSort(const Fsa &a, Fsa *b,
194193
}
195194
if (arc_map != nullptr) arc_map->swap(indexes);
196195
}
196+
197+
bool TopSort(const Fsa &a, Fsa *b,
198+
std::vector<int32_t> *state_map /*= nullptr*/) {
199+
CHECK_NOTNULL(b);
200+
b->arc_indexes.clear();
201+
b->arcs.clear();
202+
203+
if (state_map != nullptr) state_map->clear();
204+
205+
if (IsEmpty(a)) return true;
206+
if (!IsConnected(a)) return false;
207+
208+
static constexpr int8_t kNotVisited = 0; // a node that has not been visited
209+
static constexpr int8_t kVisiting = 1; // a node that is under visiting
210+
static constexpr int8_t kVisited = 2; // a node that has been visited
211+
212+
auto num_states = a.NumStates();
213+
auto final_state = num_states - 1;
214+
std::vector<int8_t> state_status(num_states, kNotVisited);
215+
216+
// map order to state.
217+
// state 0 has the largest order, i.e., num_states - 1
218+
// final_state has the least order, i.e., 0
219+
std::vector<int32_t> order;
220+
order.reserve(num_states);
221+
222+
std::stack<DfsState> stack;
223+
stack.push({0, a.arc_indexes[0], a.arc_indexes[1]});
224+
state_status[0] = kVisiting;
225+
bool is_acyclic = true;
226+
while (is_acyclic && !stack.empty()) {
227+
auto &current_state = stack.top();
228+
if (current_state.arc_begin == current_state.arc_end) {
229+
// we have finished visiting this state
230+
state_status[current_state.state] = kVisited;
231+
order.push_back(current_state.state);
232+
stack.pop();
233+
continue;
234+
}
235+
const auto &arc = a.arcs[current_state.arc_begin];
236+
auto next_state = arc.dest_state;
237+
auto status = state_status[next_state];
238+
switch (status) {
239+
case kNotVisited: {
240+
// a new discovered node
241+
state_status[next_state] = kVisiting;
242+
auto arc_begin = a.arc_indexes[next_state];
243+
if (next_state != final_state)
244+
stack.push({next_state, arc_begin, a.arc_indexes[next_state + 1]});
245+
else
246+
stack.push({next_state, arc_begin, arc_begin});
247+
++current_state.arc_begin;
248+
break;
249+
}
250+
case kVisiting:
251+
// this is a back arc indicating a loop in the graph
252+
is_acyclic = false;
253+
break;
254+
case kVisited:
255+
// this is a forward cross arc, do nothing.
256+
++current_state.arc_begin;
257+
break;
258+
default:
259+
LOG(FATAL) << "Unreachable code is executed!";
260+
break;
261+
}
262+
}
263+
264+
if (!is_acyclic) return false;
265+
266+
std::vector<int32_t> state_a_to_b(num_states);
267+
for (auto i = 0; i != num_states; ++i) {
268+
state_a_to_b[order[num_states - 1 - i]] = i;
269+
}
270+
271+
// start state maps to start state
272+
CHECK_EQ(state_a_to_b.front(), 0);
273+
// final state maps to final state
274+
CHECK_EQ(state_a_to_b.back(), final_state);
275+
276+
b->arcs.reserve(a.arc_indexes.size());
277+
b->arc_indexes.resize(num_states);
278+
279+
int32_t arc_begin;
280+
int32_t arc_end;
281+
for (auto state_b = 0; state_b != num_states; ++state_b) {
282+
auto state_a = order[num_states - 1 - state_b];
283+
arc_begin = a.arc_indexes[state_a];
284+
if (state_a != final_state)
285+
arc_end = a.arc_indexes[state_a + 1];
286+
else
287+
arc_end = arc_begin;
288+
289+
b->arc_indexes[state_b] = static_cast<int32_t>(b->arcs.size());
290+
for (; arc_begin != arc_end; ++arc_begin) {
291+
auto arc = a.arcs[arc_begin];
292+
arc.src_state = state_b;
293+
arc.dest_state = state_a_to_b[arc.dest_state];
294+
b->arcs.push_back(arc);
295+
}
296+
}
297+
if (state_map != nullptr) {
298+
std::reverse(order.begin(), order.end());
299+
state_map->swap(order);
300+
}
301+
return true;
302+
}
303+
197304
} // namespace k2

k2/csrc/fsa_algo.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,24 @@ void RandomPath(const Fsa &a, const float *a_cost, Fsa *b,
213213
arc indexes in input fsa.
214214
*/
215215
void ArcSort(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);
216+
/**
217+
Sort the input fsa topologically.
218+
219+
It returns an empty fsa when the input fsa is not acyclic,
220+
is not connected, or is empty; otherwise it returns the topologically
221+
sorted fsa in `b`.
222+
223+
@param [in] a Input fsa to be topo sorted.
224+
@param [out] b Output fsa. It is set to empty if the input fsa is empty,
225+
is not acyclic, or is not connected; otherwise it contains
226+
the topo sorted fsa.
227+
@param [out] state_map Maps from state indexes in the output fsa to
228+
state indexes in input fsa. It is empty if
229+
the output fsa is empty.
230+
@return true if the input fsa is acyclic and connected,
231+
or if the input is empty; return false otherwise.
232+
*/
233+
bool TopSort(const Fsa& a, Fsa* b, std::vector<int32_t>* state_map = nullptr);
216234

217235
/**
218236

k2/csrc/fsa_algo_test.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
#include "k2/csrc/fsa_algo.h"
99

10+
#include <string>
1011
#include <utility>
1112
#include <vector>
1213

1314
#include "gmock/gmock.h"
1415
#include "gtest/gtest.h"
16+
#include "k2/csrc/fsa_renderer.h"
17+
#include "k2/csrc/fsa_util.h"
1518

1619
namespace k2 {
1720

@@ -142,4 +145,99 @@ TEST(FsaAlgo, ArcSort) {
142145
}
143146
}
144147

148+
// Exercises TopSort on an empty fsa, on two non-connected fsas and on a
// connected acyclic fsa whose states need to be renumbered.
TEST(FsaAlgo, TopSort) {
  {
    // case 1: empty input fsa
    Fsa fsa;
    Fsa top_sorted;
    // start with a non-empty map to verify that TopSort clears it
    std::vector<int32_t> state_map(10);
    bool status = TopSort(fsa, &top_sorted, &state_map);

    EXPECT_TRUE(status);
    EXPECT_TRUE(IsEmpty(top_sorted));
    EXPECT_TRUE(state_map.empty());
  }

  {
    // case 2: non-connected fsa (not co-accessible)
    std::string s = R"(
0 2 3
1 2 1
2
)";
    auto fsa = StringToFsa(s);
    ASSERT_NE(fsa.get(), nullptr);

    Fsa top_sorted;
    std::vector<int32_t> state_map(10);
    bool status = TopSort(*fsa, &top_sorted, &state_map);

    EXPECT_FALSE(status);
    EXPECT_TRUE(IsEmpty(top_sorted));
    EXPECT_TRUE(state_map.empty());
  }

  {
    // case 3: non-connected fsa (not accessible)
    std::string s = R"(
0 2 3
1 0 1
2
)";
    auto fsa = StringToFsa(s);
    ASSERT_NE(fsa.get(), nullptr);

    Fsa top_sorted;
    std::vector<int32_t> state_map(10);
    bool status = TopSort(*fsa, &top_sorted, &state_map);

    EXPECT_FALSE(status);
    EXPECT_TRUE(IsEmpty(top_sorted));
    EXPECT_TRUE(state_map.empty());
  }

  {
    // case 4: connected fsa
    std::string s = R"(
0 4 40
0 2 20
1 6 2
2 3 30
3 6 60
3 1 10
4 5 50
5 2 8
6
)";
    auto fsa = StringToFsa(s);
    ASSERT_NE(fsa.get(), nullptr);

    Fsa top_sorted;
    std::vector<int32_t> state_map;

    // the input is connected and acyclic, so sorting must succeed
    bool status = TopSort(*fsa, &top_sorted, &state_map);
    ASSERT_TRUE(status);

    ASSERT_EQ(top_sorted.NumStates(), fsa->NumStates());

    ASSERT_FALSE(state_map.empty());
    EXPECT_THAT(state_map, ::testing::ElementsAre(0, 4, 5, 2, 3, 1, 6));

    ASSERT_FALSE(IsEmpty(top_sorted));

    const auto &arc_indexes = top_sorted.arc_indexes;
    const auto &arcs = top_sorted.arcs;

    ASSERT_EQ(arc_indexes.size(), 7u);
    EXPECT_THAT(arc_indexes, ::testing::ElementsAre(0, 2, 3, 4, 5, 7, 8));
    std::vector<Arc> expected_arcs = {
        {0, 1, 40}, {0, 3, 20}, {1, 2, 50}, {2, 3, 8},
        {3, 4, 30}, {4, 6, 60}, {4, 5, 10}, {5, 6, 2},
    };

    // check the size first so that the loop below never reads past the end
    // of `arcs`
    ASSERT_EQ(arcs.size(), expected_arcs.size());
    for (std::size_t i = 0; i != expected_arcs.size(); ++i) {
      EXPECT_EQ(arcs[i], expected_arcs[i]);
    }
  }
}
242+
145243
} // namespace k2

0 commit comments

Comments
 (0)