diff --git a/src/main/java/Log4jHotPatch.java b/src/main/java/Log4jHotPatch.java index 833a1bd..160848b 100644 --- a/src/main/java/Log4jHotPatch.java +++ b/src/main/java/Log4jHotPatch.java @@ -14,10 +14,7 @@ */ import java.io.BufferedReader; -import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; -import java.io.InputStream; import java.io.InputStreamReader; import java.lang.instrument.ClassFileTransformer; import java.lang.instrument.Instrumentation; @@ -27,10 +24,6 @@ import java.security.ProtectionDomain; import java.util.Properties; import java.util.Set; -import java.util.jar.Attributes; -import java.util.jar.JarEntry; -import java.util.jar.JarOutputStream; -import java.util.jar.Manifest; import com.sun.tools.attach.VirtualMachine; import sun.jvmstat.monitor.MonitoredHost; @@ -47,7 +40,7 @@ public class Log4jHotPatch { // version of this agent - private static final int log4jFixerAgentVersion = 1; + private static final int log4jFixerAgentVersion = 2; // property name for verbose flag public static final String LOG4J_FIXER_VERBOSE = "log4jFixerVerbose"; @@ -111,20 +104,38 @@ public static void agentmain(String args, Instrumentation inst) { ClassFileTransformer transformer = new ClassFileTransformer() { - public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, - ProtectionDomain protectionDomain, byte[] classfileBuffer) { - if (className != null && className.endsWith("org/apache/logging/log4j/core/lookup/JndiLookup")) { - log("Transforming " + className + " (" + loader + ")"); - ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); - MethodInstrumentorClassVisitor cv = new MethodInstrumentorClassVisitor(asm, cw); - ClassReader cr = new ClassReader(classfileBuffer); - cr.accept(cv, 0); - return cw.toByteArray(); - } else { - return null; - } + public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, + ProtectionDomain protectionDomain, byte[] classfileBuffer) { + if (className != null && className.endsWith("org/apache/logging/log4j/core/lookup/JndiLookup")) { + log("Transforming " + className + " (" + loader + ")"); + ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + MethodInstrumentorClassVisitor cv = new MethodInstrumentorClassVisitor(asm, cw); + ClassReader cr = new ClassReader(classfileBuffer); + cr.accept(cv, 0); + + return cw.toByteArray(); + } else if (className != null && (className.endsWith("org/apache/log4j/net/SocketServer") || className.endsWith("org/apache/log4j/net/SimpleSocketServer"))) { + log("Transforming " + className + " (" + loader + ")"); + ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + SocketServerMethodInstrumentorClassVisitor sscv = new SocketServerMethodInstrumentorClassVisitor(asm, cw); + ClassReader cr = new ClassReader(classfileBuffer); + cr.accept(sscv, 0); + return cw.toByteArray(); + } else if (className != null && className.endsWith("org/apache/log4j/net/JMSAppender")) { + log("Transforming " + className + " (" + loader + ")"); + ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + JMSAppenderMethodInstrumentorClassVisitor jmcv = new JMSAppenderMethodInstrumentorClassVisitor(asm, cw); + ClassReader cr = new ClassReader(classfileBuffer); + + cr.accept(jmcv, 0); + return cw.toByteArray(); + } else { + return null; } - }; + + + } + }; if (!staticAgent) { int patchesApplied = 0; @@ -133,7 +144,11 @@ public byte[] transform(ClassLoader loader, String className, Class classBein for (Class c : inst.getAllLoadedClasses()) { String className = c.getName(); - if (className.endsWith("org.apache.logging.log4j.core.lookup.JndiLookup")) { + if (className.endsWith("org.apache.logging.log4j.core.lookup.JndiLookup") || + className.endsWith("org.apache.log4j.net.SocketServer") || + className.endsWith("org.apache.log4j.net.SimpleSocketServer") || + className.endsWith("org.apache.log4j.net.JMSAppender") + ) { log("Patching " + c + " (" + c.getClassLoader() + ")"); try { inst.retransformClasses(c); @@ -183,6 +198,84 @@ public MethodVisitor visitMethod(int access, String name, String desc, String si } } + static class SocketServerMethodInstrumentorClassVisitor extends ClassVisitor { + private int asm; + + public SocketServerMethodInstrumentorClassVisitor(int asm, ClassVisitor cv) { + super(asm, cv); + this.asm = asm; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { + MethodVisitor mv = cv.visitMethod(access, name, desc, signature, exceptions); + if ("init".equals(name) || "main".equals(name)) { + mv = new EmptyVoidMethodInstrumentorMethodVisitor(asm, mv); + } + return mv; + } + } + + static class JMSAppenderMethodInstrumentorClassVisitor extends ClassVisitor { + private int asm; + + public JMSAppenderMethodInstrumentorClassVisitor(int asm, ClassVisitor cv) { + super(asm, cv); + this.asm = asm; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { + MethodVisitor mv = cv.visitMethod(access, name, desc, signature, exceptions); + if ("activateOptions".equals(name) || "append".equals(name) || "close".equals(name)) { + mv = new EmptyVoidMethodInstrumentorMethodVisitor(asm, mv); + } else if ("lookup".equals(name)) { + mv = new EmptyObjectMethodInstrumentorMethodVisitor(asm, mv); + } else if ("checkEntryConditions".equals(name)) { + mv = new EmptyBooleanMethodInstrumentorMethodVisitor(asm, mv); + } + return mv; + } + } + + static class EmptyVoidMethodInstrumentorMethodVisitor extends MethodVisitor implements Opcodes { + + public EmptyVoidMethodInstrumentorMethodVisitor(int asm, MethodVisitor mv) { + super(asm, mv); + } + + @Override + public void visitCode() { + mv.visitInsn(RETURN); + } + } + + static class EmptyObjectMethodInstrumentorMethodVisitor extends MethodVisitor implements Opcodes { + + public EmptyObjectMethodInstrumentorMethodVisitor(int asm, MethodVisitor mv) { + super(asm, mv); + } + + @Override + public void visitCode() { + mv.visitInsn(ACONST_NULL); + mv.visitInsn(ARETURN); + } + } + + static class EmptyBooleanMethodInstrumentorMethodVisitor extends MethodVisitor implements Opcodes { + + public EmptyBooleanMethodInstrumentorMethodVisitor(int asm, MethodVisitor mv) { + super(asm, mv); + } + + @Override + public void visitCode() { + mv.visitInsn(ICONST_1); + mv.visitInsn(IRETURN); + } + } + static class MethodInstrumentorMethodVisitor extends MethodVisitor implements Opcodes { public MethodInstrumentorMethodVisitor(int asm, MethodVisitor mv) { @@ -253,8 +346,8 @@ private static boolean loadInstrumentationAgent(String[] pids) throws Exception private static String getUID(String pid) { try { return Files.lines(FileSystems.getDefault().getPath("/proc/" + pid + "/status")). - filter(l -> l.startsWith("Uid:")). - findFirst().get().split("\\s")[1]; + filter(l -> l.startsWith("Uid:")). + findFirst().get().split("\\s")[1]; } catch (Exception e) { return null; }