-
Notifications
You must be signed in to change notification settings - Fork 7
/
Alteryx2SparkBatch.py
89 lines (70 loc) · 3.48 KB
/
Alteryx2SparkBatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import xml.etree.ElementTree as ET
from Node import NodeElement
import csv
import matplotlib.pyplot as plt
from shutil import copyfile
import sys
import networkx as nx
import os
# Batch-convert every Alteryx workflow found in INPUT_DIRECTORY into
#   * <OUTPUT_DIRECTORY>/<name>.csv - one row per tool, in topological order
#   * <OUTPUT_DIRECTORY>/<name>.png - a drawing of the workflow graph
INPUT_DIRECTORY = "input"
OUTPUT_DIRECTORY = "output"
WORKFLOW_EXTENSIONS = (".xml", ".yxmd")


def merge_fieldnames(rows):
    """Return the CSV header for *rows*: every key, in first-seen order.

    Using the union of all keys (instead of only the first row's keys) keeps
    ``csv.DictWriter`` from raising ``ValueError`` when a later row carries a
    key the first one lacks, and yields an empty header for an empty list
    instead of an ``IndexError``.  When all rows share the first row's keys
    the result equals ``list(rows[0].keys())``.
    """
    fieldnames = {}  # dict used as an insertion-ordered set
    for row in rows:
        for key in row:
            fieldnames.setdefault(key, None)
    return list(fieldnames)


def rows_in_order(rows, ordered_ids):
    """Return the dicts of *rows* arranged to follow *ordered_ids*.

    Both the rows' ``'Tool ID'`` values and the entries of *ordered_ids* are
    compared after conversion with ``int``, so ``"12"`` and ``12`` match.
    An id without a matching row contributes nothing; an id that occurs more
    than once in *ordered_ids* contributes its rows each time.
    """
    rows_by_id = {}
    for row in rows:
        rows_by_id.setdefault(int(row.get('Tool ID')), []).append(row)

    ordered_rows = []
    for tool_id in ordered_ids:
        ordered_rows.extend(rows_by_id.get(int(tool_id), []))
    return ordered_rows


def process_workflow(filepath, output_directory=OUTPUT_DIRECTORY):
    """Convert one workflow file into its CSV and PNG outputs.

    Args:
        filepath: Path of a ``.xml`` or ``.yxmd`` workflow file.
        output_directory: Directory that receives ``<name>.csv`` and
            ``<name>.png``; it must already exist.

    Raises:
        ValueError: If *filepath* has no extension or an unsupported one.

    Side effects:
        A ``.yxmd`` input is copied to a sibling ``.xml`` file and the
        original ``.yxmd`` is then deleted (behaviour kept from the original
        script).
    """
    # splitext only splits on the last dot, so "a.v1.yxmd" keeps "a.v1" as
    # its name instead of being truncated to "a".
    stem_path, extension = os.path.splitext(filepath)
    if not extension:
        raise ValueError('Input file must have an extension')
    if extension not in WORKFLOW_EXTENSIONS:
        raise ValueError('Input file must be .xml or .yxmd')

    filename = os.path.basename(stem_path)
    output_file_name = os.path.join(output_directory, filename + ".csv")
    dag_name = os.path.join(output_directory, filename + ".png")

    if extension == '.yxmd':
        xml_path = stem_path + '.xml'
        copyfile(filepath, xml_path)  # Copy to .xml
        os.remove(filepath)  # Remove .yxmd
        tree = ET.parse(xml_path)
    else:
        tree = ET.parse(filepath)
    root = tree.getroot()

    # One row per <Node>; each row is the ``data`` mapping that the project's
    # NodeElement exposes (it is read here for 'Tool ID', 'x' and 'y').
    dependency_graph = nx.DiGraph()
    node_rows = []
    for element in root.iter('Node'):
        node = NodeElement(element, root)
        node_rows.append(node.data)
        dependency_graph.add_node(node.data['Tool ID'])

    # A workflow without a <Connections> element simply has no edges.
    connections = root.find('Connections')
    if connections is not None:
        for connection in connections.iter('Connection'):
            dependency_graph.add_edge(
                connection.find('Origin').attrib.get('ToolID'),
                connection.find('Destination').attrib.get('ToolID'),
            )

    # Computed once and reused for both the CSV ordering and the layout.
    topological_order = list(nx.algorithms.topological_sort(dependency_graph))
    ordered_rows = rows_in_order(node_rows, topological_order)

    # The drawn graph only contains tools that take part in a connection.
    drawn_graph = nx.DiGraph()
    if connections is not None:
        for connection in connections.findall('Connection'):
            origin_tool_id = connection.find('Origin').attrib['ToolID']
            destination_tool_id = connection.find('Destination').attrib['ToolID']
            drawn_graph.add_edge(origin_tool_id, destination_tool_id)

    # Start from a seeded spring layout, then override it with the
    # coordinates stored in the workflow wherever a row provides them.
    pos = nx.spring_layout(topological_order, seed=142)
    for node_data in node_rows:
        x_coord = node_data.get('x')
        y_coord = node_data.get('y')
        if x_coord is not None and y_coord is not None:
            pos[node_data['Tool ID']] = (int(x_coord), int(y_coord))

    figure = plt.figure(figsize=(60, 20))
    try:
        nx.draw(drawn_graph, pos, with_labels=True, node_size=1000,
                node_color='skyblue', font_size=10, font_color='black',
                font_weight='bold', arrowsize=20)
        # Each edge is labelled with the id of the tool it points to.
        edge_labels = {(u, v): v for u, v in drawn_graph.edges}
        nx.draw_networkx_edge_labels(drawn_graph, pos,
                                     edge_labels=edge_labels, font_color='red')
        plt.title("Directed Acyclic Graph (DAG)")
        plt.savefig(dag_name, format='png', bbox_inches='tight')
    finally:
        # Without this every processed workflow leaves a 60x20 inch figure
        # open for the rest of the batch.
        plt.close(figure)

    with open(output_file_name, 'w', newline='') as output_file:
        dict_writer = csv.DictWriter(output_file, merge_fieldnames(ordered_rows))
        dict_writer.writeheader()
        dict_writer.writerows(ordered_rows)


def main():
    """Process every ``.xml`` / ``.yxmd`` file found in INPUT_DIRECTORY."""
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    for entry in os.listdir(INPUT_DIRECTORY):
        if os.path.splitext(entry)[1] not in WORKFLOW_EXTENSIONS:
            continue
        filepath = os.path.join(INPUT_DIRECTORY, entry)
        print("Start " + filepath)
        process_workflow(filepath)
        print("End " + filepath)


if __name__ == "__main__":
    main()