Skip to content

Commit

Permalink
Add gRPC Memshell Check
Browse files Browse the repository at this point in the history
  • Loading branch information
sf197 committed Dec 26, 2022
1 parent dad9053 commit 3d0f269
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 8 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ vmObj.loadAgent(agentJarPath, cfg); // agentJarPath为MemoryShellHunter jar包

### Supported middleware

1.2 Version:

- Add gRPC memory shell check algorithm

1.1 Version:

- Add Controller memory shell check algorithm
Expand All @@ -51,3 +55,9 @@ vmObj.loadAgent(agentJarPath, cfg); // agentJarPath为MemoryShellHunter jar包

##### Controller Memory Shell Test Report
![controller](./images/controller.png)



##### gRPC Memory Shell Test Report

![controller](./images/grpc.png)
Binary file added images/grpc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion src/main/java/com/websocket/findMemShell/App.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -25,6 +26,7 @@ public class App {
public static final String Change_Class = "org/apache/catalina/core/ApplicationContext";
public static final String Change_Class_Method = "<init>";
public static final String Change_Class_Method_Desc = "(Lorg/apache/catalina/core/StandardContext;)V";
public static List<String> Grpc_Methods_list = new ArrayList<>();
public static int count = 0;
public static Object servletContext = null;

Expand Down Expand Up @@ -60,7 +62,7 @@ private static void loadAgent(String arg, final Instrumentation inst) {
}
}

SearchCallsThread thread = new SearchCallsThread(discoveredCalls);
SearchCallsThread thread = new SearchCallsThread(discoveredCalls,Grpc_Methods_list);
thread.start();
System.out.println("Done!");
}
Expand All @@ -83,6 +85,11 @@ public byte[] transform(ClassLoader classLoader, String s, Class<?> aClass, Prot
//System.out.println("Dumping end ...");

return writer.toByteArray();
}else{
ClassReader reader = new ClassReader(bytes);
ClassWriter writer = new ClassWriter(reader, 0);
GrpcClassVisitor visitor = new GrpcClassVisitor(writer,Grpc_Methods_list);
reader.accept(visitor, 0);
}

ClassReader reader = new ClassReader(bytes);
Expand Down
72 changes: 72 additions & 0 deletions src/main/java/com/websocket/findMemShell/GrpcClassVisitor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package com.websocket.findMemShell;

import java.util.List;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;


public class GrpcClassVisitor extends ClassVisitor {

private String ClassName = null;
private List<String> Grpc_Methods_list;

public GrpcClassVisitor(ClassWriter writer,List<String> Grpc_Methods_list) {
super(Opcodes.ASM4, writer);
this.Grpc_Methods_list = Grpc_Methods_list;
}

@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
if(superName.contains("ServiceGrpc")) {
try {
String cls = Thread.currentThread().getContextClassLoader().loadClass(superName.replaceAll("/", "\\.")).getInterfaces()[0].getName();
if(cls.equals("io.grpc.BindableService")) {
//System.out.println("SuperName Class:"+cls);
this.ClassName = name;
}

} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
super.visit(version, access, name, signature, superName, interfaces);
}

@Override
public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
MethodVisitor methodVisitor = cv.visitMethod(access, name, desc, signature, exceptions);
if(this.ClassName == null) {
return methodVisitor;
}else {
return new MyMethodVisitor(methodVisitor, access, name, desc,this.ClassName,this.Grpc_Methods_list);
}

}

class MyMethodVisitor extends MethodVisitor implements Opcodes {
private String MethodName;
private String ClassName;
private List<String> Grpc_Methods_list;
public MyMethodVisitor(MethodVisitor mv, final int access, final String name, final String desc,String ClassName,List<String> Grpc_Methods_list) {
super(Opcodes.ASM5, mv);
this.MethodName = name;
this.ClassName = ClassName;
this.Grpc_Methods_list = Grpc_Methods_list;
}

@Override
public void visitMethodInsn(final int opcode, final String owner,
final String name, final String desc, final boolean itf) {

if(!this.Grpc_Methods_list.contains(this.ClassName+"#"+this.MethodName)) {
this.Grpc_Methods_list.add(this.ClassName+"#"+this.MethodName);
//System.out.println(this.ClassName+"#"+this.MethodName);
}
super.visitMethodInsn(opcode, owner, name, desc, itf);
}
}
}
17 changes: 15 additions & 2 deletions src/main/java/com/websocket/findMemShell/SearchCallsThread.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

public class SearchCallsThread extends Thread{
Map<String,List<String>> discoveredCalls;
List<String> Grpc_Methods_list = null;
String sinkMethod = "java/lang/Runtime#exec";
Stack stack = new Stack();
int count = 0;
List<String> visitedClass = new ArrayList<>();

public SearchCallsThread(Map<String,List<String>> discoveredCalls) {
public SearchCallsThread(Map<String,List<String>> discoveredCalls,List<String> Grpc_Methods_list) {
this.discoveredCalls = discoveredCalls;
this.Grpc_Methods_list = Grpc_Methods_list;
}

public void checkWsConfig(ConfigPath cp) {
Expand Down Expand Up @@ -77,11 +79,22 @@ public void checkControllerPath(ConfigPath cp) {
}
}

public void addGrpcListToConfigPath(List<ConfigPath> result) {
for(String cls : this.Grpc_Methods_list) {
ConfigPath cp = new ConfigPath("/Grpc",cls);
result.add(cp);
}
}

@Override
public void run() {
while(true) {
List<ConfigPath> result = getWsConfigResult.getWsConfig();
List<ConfigPath> result = new ArrayList<>();
getWsConfigResult.getWsConfig(result);
getControllerResult.getControllerMemShell(result);
if(this.Grpc_Methods_list != null) {
addGrpcListToConfigPath(result);
}
if(result != null && result.size() != 0) {
for(ConfigPath cp : result) {
if(!cp.getClassName().contains("#")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ public static boolean deleteConfig(String className) {
}
}

public static List<ConfigPath> getWsConfig() {
public static void getWsConfig(List<ConfigPath> classList) {
try {
Object servletContext = App.servletContext;
if(servletContext == null) {
return null;
return;
}
List<ConfigPath> classList = new ArrayList<>();
//System.out.println("servletContext ClassLoader: "+servletContext.getClass().getClassLoader());
Method getAttribute = servletContext.getClass().getClassLoader().loadClass("org.apache.catalina.core.ApplicationContextFacade").getDeclaredMethod("getAttribute", String.class);
Object wsServerContainer = getAttribute.invoke(servletContext, "javax.websocket.server.ServerContainer");
Expand All @@ -68,10 +67,8 @@ public static List<ConfigPath> getWsConfig() {
ConfigPath cp = new ConfigPath(key,clazz.getName());
classList.add(cp);
}
return classList;
}catch(Exception e) {
e.printStackTrace();
}
return null;
}
}

0 comments on commit 3d0f269

Please sign in to comment.