Message passing #5

Open
gavalian opened this issue Mar 26, 2024 · 9 comments

@gavalian

Hi,
I noticed that message passing was added to the library. I'm interested in graph classification with message passing.
I was wondering if an example can be provided for this.
I have the following data: each graph consists of 6 nodes, and each node has 2 features. I need to construct an adjacency matrix and a feature matrix and train a message-passing network. A sample of 10 graphs is included below; the first column is the graph number, the second is the node number, and the last two columns are the features.
How would I write code for loading this data and training it?
The real goal is to also identify graphs where one of the nodes is missing.

1,1,229.9100283176008,1.1902541030399987
1,2,230.47182154007461,1.058776646471805
1,3,203.189198652389,1.2398531706028166
1,4,204.5314318362828,1.153553470015844
1,5,185.64669584185978,1.2793633742545163
1,6,187.79421175584724,1.2198439576738058
2,1,229.4793702231205,2.330960340799244
2,2,230.33288586087744,2.210450003254505
2,3,203.11540479737127,2.3660695156585776
2,4,205.1700163108635,2.2992488439519665
2,5,185.68083877718777,2.3336671387774333
2,6,187.82261076079203,2.2953794181018194
3,1,229.32647579597085,-1.6042361520855672
3,2,230.66106436284386,-1.7481226493919342
3,3,202.93044042971965,-1.6135964029608023
3,4,204.93338840950247,-1.7067393831431075
3,5,185.57789923371803,-1.6103196248243055
3,6,188.24516335884968,-1.675625332691681
4,1,229.32882367683308,-1.6078158761501915
4,2,230.61681376907453,-1.736008671442035
4,3,202.95094620375633,-1.6236500454970282
4,4,204.89899232792726,-1.7011735459693618
4,5,185.6409650804477,-1.6271353064485021
4,6,188.25057680655323,-1.6762484217377236
5,1,230.01054032804672,2.466171048528725
5,2,230.54468404411324,2.342179997948466
5,3,202.93204096445686,2.5120332831973915
5,4,204.53248646119764,2.4362613310255226
5,5,185.79810866906044,2.5332757754491295
5,6,188.47975863206108,2.488929424115569
6,1,229.8512738011256,-0.0784807601982349
6,2,230.47171490662362,-0.19786009641288566
6,3,203.5063959707409,-0.05580110328207439
6,4,204.72183631454655,-0.1278216247793105
6,5,186.32332306504196,-0.031602383785374444
6,6,188.19302000871338,-0.07593593188648337
7,1,229.8407373900458,-0.08073549839860646
7,2,230.45752937146577,-0.2037788378524268
7,3,203.39274855559626,-0.07192003653233625
7,4,204.648717367102,-0.14605199324120002
7,5,186.00685748111547,-0.06285312285012232
7,6,187.97473760322154,-0.10808596076112409
8,1,229.32564275283303,0.2819165991854963
8,2,230.65149601075646,0.13935347898665917
8,3,202.89854691446166,0.2428692590302859
8,4,204.62697129166526,0.15248565532110503
8,5,185.63234871110154,0.22893237255190277
8,6,187.79581150015034,0.16823066116574453
9,1,229.57959752774198,2.9943233835287355
9,2,230.36740736050317,2.886351025206375
9,3,203.03815025014387,3.0001340936311043
9,4,204.52414181460338,2.9430842900344767
9,5,185.57389183018174,3.005285072530584
9,6,187.79948976767747,2.9790278914250785
10,1,229.40097705981987,-0.21886740450924871
10,2,230.35164722875328,-0.35836770227138204
10,3,202.89411232709537,-0.21348577568332344
10,4,204.64020598357496,-0.3003337905165334
10,5,185.56139436854855,-0.20822861329548692
10,6,188.13909142174575,-0.2662555763758106

maniospas self-assigned this Mar 26, 2024
@maniospas
Collaborator

Hi, based on a previous request, there's a full working example on graph classification with sort-pooling here: https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/graphClassification/SortPooling.java

For an even simpler example with mean pooling use this: https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/graphClassification/MeanPooling.java

Message passing for node classification would look like this (this is not a good architecture, but demonstrates all concepts): https://github.com/MKLab-ITI/JGNN/blob/main/JGNN/src/examples/nodeClassification/MessagePassing.java

Some explanations:
u=from(A) and v=to(A) get node indices for all edges, so that you can create a message m{l}_1=h{l}[u] | h{l}[v] whose rows correspond to edges. Notice the {l} that ties the message to the layer. The parser understands | as the concatenation operation, so the starting message basically holds the pair of source and destination node representations for each edge. You can then transform these messages with normal operations and eventually call h{l} = reduce(m{l}_3, A) to obtain the average of the edge representations m{l}_3, where the latter is some transformation of m{l}_1. In the reduction (i.e., the reduce method), each row of m{l}_3 corresponds to an edge, whereas each row of h{l} corresponds to a node.
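
To make the syntax concrete, a minimal sketch of one such layer inside a LayeredBuilder could look as follows (illustrative only; it assumes the configs hidden, 2hidden and reg have already been declared on the builder):

        builder
                .operation("u = from(A)")                 // source node index of every edge
                .operation("v = to(A)")                   // destination node index of every edge
                .operation("m{l}_1 = h{l}[u] | h{l}[v]")  // one row per edge: source | destination representations
                .operation("m{l}_2 = relu(m{l}_1@matrix(2hidden, hidden, reg)+vector(hidden))")
                .operation("m{l}_3 = m{l}_2@matrix(hidden, hidden, reg)+vector(hidden)")
                .layer("h{l+1}=reduce(m{l}_3, A)");       // average edge messages back into one row per node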

However, message passing is a complicated (and in theory often unnecessary) operation that is inevitably much slower than other architectures. Therefore, for large graphs I would caution against using JGNN; just run normal Python on the GPU, because in that case you need a high-end computing solution. Given that you parse small graphs, however, simpler architectures can probably do the trick for you. You may instead want to create positional encodings if you need more expressive power.
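
For instance, positional encodings can be as simple as extra feature columns. A sketch of appending a one-hot node-index column to each node's raw features is shown below; the helper name is illustrative and not part of JGNN:

        // illustrative helper: append a one-hot node index to each node's raw features
        public static Matrix featuresWithPositionalEncoding(double[][] raw) {
            int nodes = raw.length;
            int dims = raw[0].length;
            Matrix features = new DenseMatrix(nodes, dims + nodes);
            for(int node = 0; node < nodes; node++) {
                for(int dim = 0; dim < dims; dim++)
                    features.put(node, dim, raw[node][dim]);
                features.put(node, dims + node, 1.0); // one-hot position of the node
            }
            return features;
        }

If you do this, remember that the "features" config of the architecture should also count the extra columns.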

To get a sense of the complexity, what the above-linked message passing architecture expands to can be found here. Recall that all of this runs on the CPU.

P.S. If you have a benchmark in mind, please provide each graph's edges and graph labels too.

@gavalian
Author

Thanks for the examples. I tried several networks from the examples, but unfortunately they are unable to learn the data set; I always get [0.5, 0.5] at the output of the classifier. Could it be that I'm loading the features and adjacency matrices wrong?

Let me try the SortPooling, and I will post the results.
As far as the performance is concerned, I have done some tests, and given the size of our graphs, the performance is satisfactory.

@maniospas
Collaborator

Do let us know whether sort pooling ends up working for you, but sum pooling is usually also fine (in fact, any permutation-invariant readout suffices for graph classification).
Depending on the architecture, you might want to normalize adjacency matrices, but for small graphs I doubt this is an issue.

In terms of debugging (the checks are also grouped in a snippet after this list):

  • If your architecture runs a forward pass, you have the correct matrix sizes.
  • Just to be sure, run architecture.assertBackwardValidity(); in your code. This throws an exception if some intermediate computation never reaches the output; if no exception is thrown, your architecture is structurally valid.
  • To check that your adjacencies contain data, do System.out.println(dtrain.graphs.get(graphIdentifier).sum());. Similarly, you can check whether your features and labels contain data.
  • Make sure that you normalize graphs before the training loop (normalizing graphs again would further change their values).
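
Grouped together, the above checks look like this (field names follow the dtrain example; adjust them to your own data class):

        architecture.assertBackwardValidity();  // throws if some intermediate computation never reaches the output
        System.out.println(dtrain.graphs.get(graphIdentifier).sum());    // a non-zero sum indicates the adjacency holds data
        System.out.println(dtrain.features.get(graphIdentifier).sum());  // same check for features
        System.out.println(dtrain.labels.get(graphIdentifier).sum());    // same check for labels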

If all these are ok, I think the only way you can be getting [0.5, 0.5] is that your features need some normalization, because they have very different value ranges. In particular, the first feature may be dominating computations and yielding two very large but similar values in both output classes.
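
As an illustration only (plain Java, not a JGNN utility), per-column z-score normalization of one graph's raw features could look like the sketch below; computing the statistics over the whole training set instead of per graph is usually even better:

        // illustrative sketch: z-score normalize each feature column of one graph
        public static Matrix normalizedFeatures(double[][] raw) {
            int nodes = raw.length;
            int dims = raw[0].length;
            Matrix features = new DenseMatrix(nodes, dims);
            for(int dim = 0; dim < dims; dim++) {
                double mean = 0;
                for(int node = 0; node < nodes; node++)
                    mean += raw[node][dim] / nodes;
                double var = 0;
                for(int node = 0; node < nodes; node++)
                    var += (raw[node][dim] - mean) * (raw[node][dim] - mean) / nodes;
                double std = Math.sqrt(var) + 1.e-12;  // guard against zero variance
                for(int node = 0; node < nodes; node++)
                    features.put(node, dim, (raw[node][dim] - mean) / std);
            }
            return features;
        }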

Note: There is always a small chance that the problem you are trying to solve is not perfectly solvable with GNNs (contrary to traditional neural networks, GNNs cannot approximate every graph function - simple architectures have only WL-1 or WL-2 expressive power, and universal approximation usually requires some form of computational unboundedness). However, even when not perfectly solvable, your architectures should be learning something. If all else fails, I recommend trying a simple GCN, which in my experience always learns something, even in very hard tasks, despite not always being the best. (But do try normalizing your features first, because they are probably the issue here.)

P.S. If you want to, upload some code here so that I can take a look too.

@gavalian
Author

Thanks for the suggestions.
I did, in fact, check the matrices, and it seems that I'm filling them right.
Sorry for the confusion:
The sample data I posted was from before I normalized it to feed to the network. I'm currently running the SortPooling example, and it does indeed learn; however, progress is very slow: it took me 1,600 iterations to get to 86%, with regularization at 0.001 and a learning rate of 0.01, but accuracy steadily grows.

The message-passing example deals with one graph only, if I understand this correctly. Is there a way to train it on many graphs?
Once I exhaust all my tests, I will post a code and data sample (I will probably need some help). This particular problem has already been solved successfully with a message-passing method in Python; I want to see if I can get it working in Java. So I know that message passing will work with this particular dataset.

@maniospas
Collaborator

maniospas commented Mar 26, 2024

You can just copy-paste the layer definitions. Here's an example:

        ModelBuilder builder = new LayeredBuilder()
                .var("A")
                .config("features", 1)
                .config("classes", 2)
                //.config("reduced", reduced) // only needed for sort pooling
                .config("hidden", hidden)
                .config("2hidden", 2*hidden)
                .config("reg", 0.005)
                .operation("edgeSrc = from(A)")
                .operation("edgeDst = to(A)")
                .layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
                .layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")

                // 1st part of message passing layer (make it as complex as needed): message transformation
                .operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
                .operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")

                // 2nd part of message passing layer: receive the message, concatenate it with own representation and transform it
                .operation("received{l}=reduce(transformed{l}, A)")
                .operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
                .layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")

                // the commented part would be sort pooling
                /*.config("hiddenReduced", hidden*reduced)  // reduced * (previous layer's output size)
                .operation("z{l}=sort(h{l}, reduced)")  // currently, the parser fails to understand full expressions within the next step's gather, so we need to create this intermediate variable
                .layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")
                .layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
                .layer("h{l+1}=softmax(h{l}, row)")*/

                // the following two layers implement sum pooling (transform to classes dimensions first)
                .layer("h{l+1}=sum(h{l}@matrix(hidden, classes)+vector(classes), row)")
                .layer("h{l+1}=softmax(h{l}, row)")

                .out("h{l}");

For the example data it learns quickly and accurately where the previous architecture wasn't performing so well (for this smaller dataset it reaches 97% accuracy before epoch 40):

TrajectoryData dtrain = new TrajectoryData(800);
TrajectoryData dtest = new TrajectoryData(200);

@gavalian
Author

I tried the above example, but it's having a hard time training on my data, which is essentially a trajectory as well.
I put all the code and data files here:

https://drive.google.com/drive/folders/1KgtInDoiQ3WgztrW-itzohYfBFUAHNC-?usp=sharing

I have a class there that reads the data.

@gavalian
Author

Hi Emmanouil,
Just to clarify my intentions with this study: if this works for the particular data set I sent you, it will become a peer-reviewed publication, where you would have full authorship rights. You can check out the publications.txt file at the link I sent you to see what kind of work I do.

@maniospas
Collaborator

Precondition: update to JGNN 1.2.0

Before running code from this reply, please update to the latest version.
I fixed a bug that was adding some small noise to computations (huge props for the bug hunting). It wasn't noticeable in the large graphs I was testing with, because it did not impact performance by any noticeable amount there. However, it was throwing an error when there were more hidden layers than the number of nodes (if adjacencies were created correctly).

Fix

In terms of code, I don't understand why you are using a getIndex function (mostly because it was not included in your code).
I imagine what you meant to do is the following:

public static Matrix createAdjacencyMatrix() {
    Matrix m = new DenseMatrix(6, 6);
    for(int i = 0; i < 6; i++)
        m.put(i, i, 1.0);    // self-loops
    for(int i = 0; i < 5; i++) {
        m.put(i, i+1, 1.0);  // chain edges between consecutive nodes
        m.put(i+1, i, 1.0);  // symmetric (undirected) edge
    }
    return m;
}
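
For completeness, a loader for the CSV format posted at the top of this issue could look roughly like the sketch below (illustrative only; it needs java.nio.file and java.util imports plus JGNN's DenseMatrix, and assumes 6 nodes with 2 features per graph as in your sample):

        // illustrative sketch: parse lines of the form graphId,nodeId,feature1,feature2
        public static List<Matrix> loadFeatureMatrices(String path) throws IOException {
            Map<Integer, Matrix> perGraph = new TreeMap<>();
            for(String line : Files.readAllLines(Paths.get(path))) {
                String[] parts = line.trim().split(",");
                if(parts.length < 4)
                    continue;  // skip empty or malformed lines
                int graph = Integer.parseInt(parts[0]);
                int node = Integer.parseInt(parts[1]) - 1;  // the file numbers nodes from 1
                Matrix features = perGraph.computeIfAbsent(graph, g -> new DenseMatrix(6, 2));
                features.put(node, 0, Double.parseDouble(parts[2]));
                features.put(node, 1, Double.parseDouble(parts[3]));
            }
            return new ArrayList<>(perGraph.values());
        }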

After the fix, I tried your code and sort pooling seemed to be better:

        int reduced = 6; // equal to the number of nodes, so as not to lose information (sorting then just makes a lossless permutation-invariant layer);
                         // do not do this if you have a lot of nodes - keep the number small, because it can be very computationally intensive in that case
        int hidden = 8;
        ModelBuilder builder = new LayeredBuilder()        
                .var("A")  
                .config("features", 2)
                .config("classes", 2)
                .config("reduced", reduced)
                .config("hidden", hidden)
                .config("2hidden", 2*hidden)
                .config("reg", 0.005)
                .operation("edgeSrc = from(A)")
                .operation("edgeDst = to(A)")
                .layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
                .layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")
                
                // 1st part of message passing layer (make it as complex as needed): message transformation
                .operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
                .operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")
                
                // 2nd part of message passing layer: receive the message, concatenate it with own representation and transform it
                .operation("received{l}=reduce(transformed{l}, A)")
                .operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
                .layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")
                
                // sort pooling (the part that was commented out above)
                .config("hiddenReduced", hidden*reduced)  // reduced * (previous layer's output size)
                .operation("z{l}=sort(h{l}, reduced)")  // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable
                .layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") //
                .layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
                .layer("h{l+1}=softmax(h{l}, row)")
				
                .out("h{l}");

Random suggestions

  1. I strongly recommend mini-batching (running backpropagation on a random subset of graphs in each epoch) and computing accuracy only once every fixed number of epochs, because you have a lot of data. A learning rate of 0.1 was also alright to speed up convergence, and absolutely run things in thread pools so that they take advantage of multiple processors. This looks like the following:
        Model model = builder.getModel().init(new XavierNormal());
        BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.1));
        Loss loss = new CategoricalCrossEntropy();  //.setMeanReduction(true);
        for(int epoch=0; epoch<300; epoch++) {
            // gradient update over all graphs
            for(int graphId=0; graphId<tr.features.size(); graphId++) {
                if(Math.random()<0.1) // randomly drop ~10% of graphs from this epoch's batch
                    continue;
                int graphIdentifier = graphId;
                // each gradient calculation runs as a separate thread pool task
                ThreadPool.getInstance().submit(new Runnable() {
                    @Override
                    public void run() {
                        Matrix adjacency = tr.adjacency.get(graphIdentifier);
                        Matrix features = tr.features.get(graphIdentifier);
                        Tensor graphLabel = tr.labels.get(graphIdentifier).asRow();
                        model.train(loss, optimizer,
                                Arrays.asList(features, adjacency),
                                Arrays.asList(graphLabel));
                    }
                });
            }
            ThreadPool.getInstance().waitForConclusion();  // wait for all gradients to be computed
            optimizer.updateAll();  // apply the accumulated gradients to the model parameters

            // evaluate every 10 epochs only
            if(epoch%10!=0)
                continue;
            double acc = 0.0;
            for(int graphId=0; graphId<ts.adjacency.size(); graphId++) {
                Matrix adjacency = ts.adjacency.get(graphId);
                Matrix features = ts.features.get(graphId);
                Tensor graphLabel = ts.labels.get(graphId);
                if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
                    acc += 1;
            }
            System.out.println("iter = " + epoch + "  " + acc/ts.adjacency.size());
        }

This yields training set accuracy a little over 90% at epoch 300 for me. (I believe it would keep increasing.)

  2. Some helpful prints you could use for further assertions:

        System.out.println(features.describe());
        System.out.println(graphLabel.describe());
        System.out.println(adjacency.describe());
        System.out.println(adjacency); // actually prints the contents in the console

  3. If you are planning to make a publication, I strongly advise writing proper layers for it - these are just demos.
  4. Please don't forget to do a training-validation-test split of your data for the final experiments (easy version: keep the test accuracy measured at the best validation accuracy; hard version: copy the parameters at that point - open a different issue to request a tutorial for this if needed, because I don't think one currently exists). A rough sketch of the easy version follows below.
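
Illustrative only: vs and ts below stand for the validation and test splits (same structure as the tr/ts holders above), YourDatasetSplit is a placeholder for whatever class holds them, and the training updates inside the epoch loop are abbreviated:

        // hypothetical accuracy helper mirroring the evaluation loop above
        static double evaluate(Model model, YourDatasetSplit data) {
            double acc = 0;
            for(int graphId = 0; graphId < data.adjacency.size(); graphId++)
                if(model.predict(Arrays.asList(data.features.get(graphId), data.adjacency.get(graphId)))
                        .get(0).argmax() == data.labels.get(graphId).argmax())
                    acc += 1;
            return acc / data.adjacency.size();
        }

        // easy version: remember the test accuracy measured at the best validation accuracy so far
        double bestValidationAcc = 0;
        double testAccAtBestValidation = 0;
        for(int epoch = 0; epoch < 300; epoch++) {
            // ... training updates for this epoch, as in the loop above ...
            double validationAcc = evaluate(model, vs);
            if(validationAcc > bestValidationAcc) {
                bestValidationAcc = validationAcc;
                testAccAtBestValidation = evaluate(model, ts);
            }
        }
        System.out.println("test accuracy at best validation epoch: " + testAccAtBestValidation);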

Re: Authorship

Since I am not doing any actual research for which I could take responsibility, please just add me to the Acknowledgements ("The authors would like to thank Emmanouil Krasanakis from CERTH for..."). I would appreciate this a lot. Do send me a version of your paper. Please cite JGNN's publication in the text:

Krasanakis, Emmanouil, Symeon Papadopoulos, and Ioannis Kompatsiaris. "JGNN: Graph Neural Networks on native Java." SoftwareX 23 (2023): 101459.

Thanks a lot for sharing your research directions. As someone working almost exclusively on enabling technologies it feels nice knowing that they do, in fact, have value in other scientific fields. :-)

@gavalian
Author

gavalian commented Apr 2, 2024

This worked, and the network learned. However, I think I may be going about this the wrong way. Maybe you can suggest the right direction. The problem I'm trying to solve is the following.

I have a collection of N nodes (which I can make fixed, say 128 nodes). I need to define reasonable connections between the nodes (edges) and then teach the network to assign weights to those connections. A minimum of 6 and a maximum of 12 nodes have to be present to make a valid collection. Is there an architecture that can work for this kind of problem?

As of now I do not have a paper draft yet; I'm doing the feasibility studies, but if this works, it will turn into a paper (I'll definitely cite your work and add the acknowledgement).
