1313#include < utility>
1414
1515#include " glog/logging.h"
16+ #include " k2/csrc/properties.h"
1617
1718namespace {
1819
@@ -57,7 +58,7 @@ void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map) {
5758 auto state = current_state.state ; // get a copy since we will destroy it
5859 stack.pop ();
5960 if (!stack.empty ()) {
60- // if it has a parent, set the parent's coaccessible flag
61+ // if it has a parent, set the parent's co-accessible flag
6162 if (coaccessible[state]) {
6263 auto &parent = stack.top ();
6364 coaccessible[parent.state ] = true ;
@@ -98,7 +99,7 @@ void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map) {
9899
99100void Connect (const Fsa &a, Fsa *b, std::vector<int32_t > *arc_map /* =nullptr*/ ) {
100101 CHECK_NOTNULL (b);
101- if (arc_map) arc_map->clear ();
102+ if (arc_map != nullptr ) arc_map->clear ();
102103
103104 std::vector<int32_t > state_b_to_a;
104105 ConnectCore (a, &state_b_to_a);
@@ -108,7 +109,7 @@ void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
108109 b->arcs .clear ();
109110 b->arcs .reserve (a.arcs .size ());
110111
111- if (arc_map) {
112+ if (arc_map != nullptr ) {
112113 arc_map->clear ();
113114 arc_map->reserve (a.arcs .size ());
114115 }
@@ -144,9 +145,7 @@ void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
144145 arc.src_state = i;
145146 arc.dest_state = state_b;
146147 b->arcs .push_back (arc);
147- if (arc_map) {
148- arc_map->push_back (arc_begin);
149- }
148+ if (arc_map != nullptr ) arc_map->push_back (arc_begin);
150149 }
151150 }
152151}
@@ -194,4 +193,112 @@ void ArcSort(const Fsa &a, Fsa *b,
194193 }
195194 if (arc_map != nullptr ) arc_map->swap (indexes);
196195}
196+
197+ bool TopSort (const Fsa &a, Fsa *b,
198+ std::vector<int32_t > *state_map /* = nullptr*/ ) {
199+ CHECK_NOTNULL (b);
200+ b->arc_indexes .clear ();
201+ b->arcs .clear ();
202+
203+ if (state_map != nullptr ) state_map->clear ();
204+
205+ if (IsEmpty (a)) return true ;
206+ if (!IsConnected (a)) return false ;
207+
208+ static constexpr int8_t kNotVisited = 0 ; // a node that has not been visited
209+ static constexpr int8_t kVisiting = 1 ; // a node that is under visiting
210+ static constexpr int8_t kVisited = 2 ; // a node that has been visited
211+
212+ auto num_states = a.NumStates ();
213+ auto final_state = num_states - 1 ;
214+ std::vector<int8_t > state_status (num_states, kNotVisited );
215+
216+ // map order to state.
217+ // state 0 has the largest order, i.e., num_states - 1
218+ // final_state has the least order, i.e., 0
219+ std::vector<int32_t > order;
220+ order.reserve (num_states);
221+
222+ std::stack<DfsState> stack;
223+ stack.push ({0 , a.arc_indexes [0 ], a.arc_indexes [1 ]});
224+ state_status[0 ] = kVisiting ;
225+ bool is_acyclic = true ;
226+ while (is_acyclic && !stack.empty ()) {
227+ auto ¤t_state = stack.top ();
228+ if (current_state.arc_begin == current_state.arc_end ) {
229+ // we have finished visiting this state
230+ state_status[current_state.state ] = kVisited ;
231+ order.push_back (current_state.state );
232+ stack.pop ();
233+ continue ;
234+ }
235+ const auto &arc = a.arcs [current_state.arc_begin ];
236+ auto next_state = arc.dest_state ;
237+ auto status = state_status[next_state];
238+ switch (status) {
239+ case kNotVisited : {
240+ // a new discovered node
241+ state_status[next_state] = kVisiting ;
242+ auto arc_begin = a.arc_indexes [next_state];
243+ if (next_state != final_state)
244+ stack.push ({next_state, arc_begin, a.arc_indexes [next_state + 1 ]});
245+ else
246+ stack.push ({next_state, arc_begin, arc_begin});
247+ ++current_state.arc_begin ;
248+ break ;
249+ }
250+ case kVisiting :
251+ // this is a back arc indicating a loop in the graph
252+ is_acyclic = false ;
253+ break ;
254+ case kVisited :
255+ // this is a forward cross arc, do nothing.
256+ ++current_state.arc_begin ;
257+ break ;
258+ default :
259+ LOG (FATAL) << " Unreachable code is executed!" ;
260+ break ;
261+ }
262+ }
263+
264+ if (!is_acyclic) return false ;
265+
266+ std::vector<int32_t > state_a_to_b (num_states);
267+ for (auto i = 0 ; i != num_states; ++i) {
268+ state_a_to_b[order[num_states - 1 - i]] = i;
269+ }
270+
271+ // start state maps to start state
272+ CHECK_EQ (state_a_to_b.front (), 0 );
273+ // final state maps to final state
274+ CHECK_EQ (state_a_to_b.back (), final_state);
275+
276+ b->arcs .reserve (a.arc_indexes .size ());
277+ b->arc_indexes .resize (num_states);
278+
279+ int32_t arc_begin;
280+ int32_t arc_end;
281+ for (auto state_b = 0 ; state_b != num_states; ++state_b) {
282+ auto state_a = order[num_states - 1 - state_b];
283+ arc_begin = a.arc_indexes [state_a];
284+ if (state_a != final_state)
285+ arc_end = a.arc_indexes [state_a + 1 ];
286+ else
287+ arc_end = arc_begin;
288+
289+ b->arc_indexes [state_b] = static_cast <int32_t >(b->arcs .size ());
290+ for (; arc_begin != arc_end; ++arc_begin) {
291+ auto arc = a.arcs [arc_begin];
292+ arc.src_state = state_b;
293+ arc.dest_state = state_a_to_b[arc.dest_state ];
294+ b->arcs .push_back (arc);
295+ }
296+ }
297+ if (state_map != nullptr ) {
298+ std::reverse (order.begin (), order.end ());
299+ state_map->swap (order);
300+ }
301+ return true ;
302+ }
303+
197304} // namespace k2
0 commit comments