Message passing #5
Hi, based on a previous request, there's a full working example of graph classification with sort pooling here: https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/graphClassification/SortPooling.java
For an even simpler example with mean pooling, use this: https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/graphClassification/MeanPooling.java
Message passing for node classification would look like this (this is not a good architecture, but it demonstrates all the concepts): https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/nodeClassification/MessagePassing.java
Some explanations: […] However, to get a sense of the complexity, what the above-linked message passing architecture looks like can be found here. Recall that all of these computations happen on the CPU.
P.S. If you have a benchmark in mind, please provide each graph's edges and graph labels too. |
Thanks for the examples. I tried several networks from the examples, but unfortunately they are unable to learn the data set; I always get [0.5, 0.5] in the output of the classifier. Could it be that I'm loading the features and the adjacency matrix wrong? Let me try the SortPooling example, and I will post the results. |
Do let us know whether sort pooling ends up working for you, but sum pooling is also usually fine (in fact, any permutation-invariant pooling operation suffices for graph classification). In terms of debugging: […]
If all of these are OK, I think the only way you could be getting [0.5, 0.5] is that your features need some normalization, because they have very different value ranges. In particular, the first feature may simply be dominating the computations, yielding two very large but similar values in both output classes.

Note: there is always a small chance that the problem you are trying to solve is not perfectly solvable with GNNs (contrary to traditional neural networks, GNNs cannot approximate every graph function - simple architectures have only WL-1 or WL-2 expressive power, and universal approximation usually requires some form of computational unboundedness). However, even when a problem is not perfectly solvable, your architectures should still be learning something. If all else fails, I recommend trying a simple GCN, which in my experience always learns something, even in very hard tasks, despite not always being the best. (But do try normalizing your features first, because they are probably the issue here.)

P.S. If you want to, upload some code here so that I can take a look too. |
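For concreteness, a minimal normalization sketch (assuming JGNN's `Matrix` exposes `get`/`put` and `getRows()`/`getCols()` accessors, consistent with the snippets later in this thread) could rescale each feature column to [0, 1]:

```java
import mklab.JGNN.core.Matrix;

// Hedged sketch: min-max scale each feature column to [0, 1] in place,
// so that no single feature dominates the computations.
public static Matrix minMaxNormalize(Matrix features) {
	for(long col = 0; col < features.getCols(); col++) {
		double min = Double.POSITIVE_INFINITY;
		double max = Double.NEGATIVE_INFINITY;
		for(long row = 0; row < features.getRows(); row++) {
			min = Math.min(min, features.get(row, col));
			max = Math.max(max, features.get(row, col));
		}
		double range = max == min ? 1 : max - min; // guard against constant columns
		for(long row = 0; row < features.getRows(); row++)
			features.put(row, col, (features.get(row, col) - min) / range);
	}
	return features;
}
```

Applying something like this to every graph's feature matrix before training would put all features on the same scale.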
Thanks for the suggestions. The message-passing example deals with one graph only, if I understand this correctly. Is there a way to train it on many graphs? |
You can just copy-paste the layer definitions. Here's an example:

```java
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 1)
.config("classes", 2)
//.config("reduced", reduced) // only needed for sort pooling
.config("hidden", hidden)
.config("2hidden", 2*hidden)
.config("reg", 0.005)
.operation("edgeSrc = from(A)")
.operation("edgeDst = to(A)")
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")
// 1st part of message passing layer (make it as complex as needed): message transformation
.operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
.operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")
// 2nd part of message passing layer: receive the message, concatenate it with own representation and transform it
.operation("received{l}=reduce(transformed{l}, A)")
.operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
.layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")
// the commented part would be sort pooling
/*.config("hiddenReduced", hidden*reduced) // reduced * (previous layer's output size)
.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") //
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, row)")*/
// the following two layers implement sum pooling (transform to classes dimensions first)
.layer("h{l+1}=sum(h{l}@matrix(hidden, classes)+vector(classes), row)")
.layer("h{l+1}=softmax(h{l}, row)")
.out("h{l}"); For the example it learns pretty fastly and accurately where the previous example wasn't performing so well (for this smaller dataset it reaches 97% accuracy before epoch 40): TrajectoryData dtrain = new TrajectoryData(800);
TrajectoryData dtest = new TrajectoryData(200);
```
|
I tried the above example, but it's having a hard time training on my data, which is essentially a trajectory as well. https://drive.google.com/drive/folders/1KgtInDoiQ3WgztrW-itzohYfBFUAHNC-?usp=sharing I have a class there that reads the data. |
Hi Emmanouil, |
**Precondition: update to JGNN 1.2.0**
Before running code from this reply, please update to the latest version.

**Fix**
In terms of code, I don't understand why you are using a […]. The adjacency matrix can be constructed like this:

```java
// Symmetric adjacency matrix of a 6-node chain, with self-loops.
public static Matrix createAdjancencyMatrix() {
	Matrix m = new DenseMatrix(6, 6);
	for(int i = 0; i < 6; i++)
		m.put(i, i, 1.0); // self-loops
	for(int i = 0; i < 5; i++) {
		m.put(i, i+1, 1.0); // chain edges in both directions
		m.put(i+1, i, 1.0);
	}
	return m;
}
```

After the fix, I tried your code and sort pooling seemed to be better:

```java
int reduced = 6; // equal to the number of nodes, to not lose information (sorting then just becomes a lossless permutation-invariant layer - do not do this if you have a lot of nodes; keep the number small in that case, because it can be very computationally intensive)
int hidden = 8;
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 2)
.config("classes", 2)
.config("reduced", reduced)
.config("hidden", hidden)
.config("2hidden", 2*hidden)
.config("reg", 0.005)
.operation("edgeSrc = from(A)")
.operation("edgeDst = to(A)")
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")
// 1st part of message passing layer (make it as complex as needed): message transformation
.operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
.operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")
// 2nd part of message passing layer: receive the message, concatenate it with own representation and transform it
.operation("received{l}=reduce(transformed{l}, A)")
.operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
.layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")
// the commented part would be sort pooling
.config("hiddenReduced", hidden*reduced) // reduced * (previous layer's output size)
.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") //
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, row)")
.out("h{l}"); Random suggestions
Model model = builder.getModel().init(new XavierNormal());
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.1));
Loss loss = new CategoricalCrossEntropy();//.setMeanReduction(true);
for(int epoch=0; epoch<300; epoch++) {
// gradient update over all graphs
for(int graphId=0; graphId<tr.features.size(); graphId++) {
if(Math.random()<0.1) // randomly drop 10% of the graphs from each epoch's batch
continue;
int graphIdentifier = graphId; // effectively final copy, required inside the anonymous Runnable
// each gradient calculation into a new thread pool task
ThreadPool.getInstance().submit(new Runnable() {
@Override
public void run() {
Matrix adjacency = tr.adjacency.get(graphIdentifier);
Matrix features = tr.features.get(graphIdentifier);
Tensor graphLabel = tr.labels.get(graphIdentifier).asRow();
model.train(loss, optimizer,
Arrays.asList(features, adjacency),
Arrays.asList(graphLabel));
}
});
}
ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to compute
optimizer.updateAll(); // apply gradients on model parameters
// evaluate every 10 epochs only
if(epoch%10!=0)
continue;
double acc = 0.0;
for(int graphId=0; graphId<ts.adjacency.size(); graphId++) {
Matrix adjacency = ts.adjacency.get(graphId);
Matrix features = ts.features.get(graphId);
Tensor graphLabel = ts.labels.get(graphId);
if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
acc += 1;
}
System.out.println("iter = " + epoch + " " + acc/ts.adjacency.size());
}
```

This yields training set accuracy a little over 90% at epoch 300 for me. (I believe it would keep increasing.)
**Re: Authorship**
Since I am not doing any actual research for which I could take responsibility, please just add me in the Acknowledgements ("The authors would like to thank Emmanouil Krasanakis from CERTH for..."). I would appreciate this a lot. Do send me a version of your paper. Please also cite JGNN's publication in the text.
Thanks a lot for sharing your research directions. As someone working almost exclusively on enabling technologies, it feels nice to know that they do, in fact, have value in other scientific fields. :-) |
This worked, and the network learned. However, I think I may be going about this the wrong way; maybe you can suggest the right direction. The problem I'm trying to solve is the following: I have a collection of N nodes (which I can make fixed, say 128 nodes), and I need to define reasonable connections between the nodes (edges) and then teach the network to assign weights to those connections. A minimum of 6 nodes has to be present to make a valid collection, and a maximum of 12. Is there an architecture that can work for this kind of problem? As of now I do not have a paper draft yet (I'm doing the feasibility studies), but if this works, it will turn into a paper (I'll definitely cite your work and add the acknowledgement). |
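A minimal sketch of one possible direction (hypothetical and unvetted, continuing the conventions of the builder snippets above) reuses the `from`/`to` edge gathering: embed the nodes, gather both endpoints of each candidate edge, and map the concatenation to a single weight per edge. It assumes the expression parser supports `sigmoid` alongside the `relu` and `softmax` used earlier:

```java
// Hypothetical edge-weight scorer: embed nodes, then map each edge's
// concatenated endpoint embeddings to one predicted weight in (0, 1).
ModelBuilder edgeScorer = new LayeredBuilder()
	.var("A")
	.config("features", 2)
	.config("hidden", 8)
	.config("2hidden", 16)
	.operation("edgeSrc = from(A)")
	.operation("edgeDst = to(A)")
	.layer("h{l+1}=relu(h{l}@matrix(features, hidden)+vector(hidden))")
	.layer("h{l+1}=relu(h{l}@matrix(hidden, hidden)+vector(hidden))")
	// gather the two endpoints of every edge and concatenate them
	.operation("edgeFeats{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
	// one scalar output per edge, interpreted as its connection weight
	.layer("h{l+1}=sigmoid(edgeFeats{l}@matrix(2hidden, 1)+vector(1))")
	.out("h{l}");
```

Training such a scorer would need per-edge target weights in place of the graph labels used earlier in the thread.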
Hi,
I noticed that message passing was added to the library. I'm interested in graph classification with message passing.
I was wondering if an example can be provided for this.
I have the following data: graphs consisting of 6 nodes, each node having 2 features. I need to construct an adjacency matrix and a feature matrix and train a message-passing network with them. An example file with 10 graphs is included below; the first column is the graph number, the second is the node number, and the last two columns are the features.
How would I write code for loading this data and training it? (See the loading sketch after the data below.)
The real goal is to also identify graphs where one of the nodes is missing.
```
1,1,229.9100283176008,1.1902541030399987
1,2,230.47182154007461,1.058776646471805
1,3,203.189198652389,1.2398531706028166
1,4,204.5314318362828,1.153553470015844
1,5,185.64669584185978,1.2793633742545163
1,6,187.79421175584724,1.2198439576738058
2,1,229.4793702231205,2.330960340799244
2,2,230.33288586087744,2.210450003254505
2,3,203.11540479737127,2.3660695156585776
2,4,205.1700163108635,2.2992488439519665
2,5,185.68083877718777,2.3336671387774333
2,6,187.82261076079203,2.2953794181018194
3,1,229.32647579597085,-1.6042361520855672
3,2,230.66106436284386,-1.7481226493919342
3,3,202.93044042971965,-1.6135964029608023
3,4,204.93338840950247,-1.7067393831431075
3,5,185.57789923371803,-1.6103196248243055
3,6,188.24516335884968,-1.675625332691681
4,1,229.32882367683308,-1.6078158761501915
4,2,230.61681376907453,-1.736008671442035
4,3,202.95094620375633,-1.6236500454970282
4,4,204.89899232792726,-1.7011735459693618
4,5,185.6409650804477,-1.6271353064485021
4,6,188.25057680655323,-1.6762484217377236
5,1,230.01054032804672,2.466171048528725
5,2,230.54468404411324,2.342179997948466
5,3,202.93204096445686,2.5120332831973915
5,4,204.53248646119764,2.4362613310255226
5,5,185.79810866906044,2.5332757754491295
5,6,188.47975863206108,2.488929424115569
6,1,229.8512738011256,-0.0784807601982349
6,2,230.47171490662362,-0.19786009641288566
6,3,203.5063959707409,-0.05580110328207439
6,4,204.72183631454655,-0.1278216247793105
6,5,186.32332306504196,-0.031602383785374444
6,6,188.19302000871338,-0.07593593188648337
7,1,229.8407373900458,-0.08073549839860646
7,2,230.45752937146577,-0.2037788378524268
7,3,203.39274855559626,-0.07192003653233625
7,4,204.648717367102,-0.14605199324120002
7,5,186.00685748111547,-0.06285312285012232
7,6,187.97473760322154,-0.10808596076112409
8,1,229.32564275283303,0.2819165991854963
8,2,230.65149601075646,0.13935347898665917
8,3,202.89854691446166,0.2428692590302859
8,4,204.62697129166526,0.15248565532110503
8,5,185.63234871110154,0.22893237255190277
8,6,187.79581150015034,0.16823066116574453
9,1,229.57959752774198,2.9943233835287355
9,2,230.36740736050317,2.886351025206375
9,3,203.03815025014387,3.0001340936311043
9,4,204.52414181460338,2.9430842900344767
9,5,185.57389183018174,3.005285072530584
9,6,187.79948976767747,2.9790278914250785
10,1,229.40097705981987,-0.21886740450924871
10,2,230.35164722875328,-0.35836770227138204
10,3,202.89411232709537,-0.21348577568332344
10,4,204.64020598357496,-0.3003337905165334
10,5,185.56139436854855,-0.20822861329548692
10,6,188.13909142174575,-0.2662555763758106
```
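A minimal loading sketch for the format above (graph number, node number, two features; 6 nodes per graph), assuming JGNN's `Matrix`/`DenseMatrix` classes under the `mklab.JGNN.core` packages as used elsewhere in this thread; the method name and error handling are illustrative only:

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import mklab.JGNN.core.Matrix;
import mklab.JGNN.core.matrix.DenseMatrix;

// Hypothetical loader: parses "graphId,nodeId,feature1,feature2" lines
// into one 6x2 feature matrix per graph (ids in the file are 1-based).
public static List<Matrix> loadFeatures(String path) throws Exception {
	List<Matrix> graphs = new ArrayList<>();
	try(BufferedReader reader = new BufferedReader(new FileReader(path))) {
		String line;
		while((line = reader.readLine()) != null) {
			String[] cols = line.split(",");
			int graphId = Integer.parseInt(cols[0]) - 1;
			int nodeId = Integer.parseInt(cols[1]) - 1;
			while(graphs.size() <= graphId)
				graphs.add(new DenseMatrix(6, 2));
			graphs.get(graphId).put(nodeId, 0, Double.parseDouble(cols[2]));
			graphs.get(graphId).put(nodeId, 1, Double.parseDouble(cols[3]));
		}
	}
	return graphs;
}
```

Each resulting feature matrix could then be paired with the chain adjacency from `createAdjancencyMatrix()` above and fed to `model.train` as in the training loop earlier in the thread.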