From 264664e369c742004c0999e684e4b33d58e5e57a Mon Sep 17 00:00:00 2001 From: Dmitry Schitinin Date: Thu, 12 Jan 2023 19:28:24 +0300 Subject: [PATCH] Fixes BigDecimal deserialization (#67) --- src/one/nio/serial/GeneratedSerializer.java | 5 +- src/one/nio/serial/Repository.java | 3 + src/one/nio/serial/gen/DelegateGenerator.java | 66 ++++++-- .../nio/serial/gen/GetFieldInputStream.java | 149 ++++++++++++++++++ .../nio/serial/gen/NullObjectInputStream.java | 3 +- test/one/nio/serial/ConversionTest.java | 5 +- test/one/nio/serial/DefaultFieldsTest.java | 5 +- test/one/nio/serial/SerializationTest.java | 28 +++- 8 files changed, 231 insertions(+), 33 deletions(-) create mode 100644 src/one/nio/serial/gen/GetFieldInputStream.java diff --git a/src/one/nio/serial/GeneratedSerializer.java b/src/one/nio/serial/GeneratedSerializer.java index c3bbcad..2acc362 100755 --- a/src/one/nio/serial/GeneratedSerializer.java +++ b/src/one/nio/serial/GeneratedSerializer.java @@ -16,7 +16,6 @@ package one.nio.serial; -import one.nio.gen.BytecodeGenerator; import one.nio.serial.gen.Delegate; import one.nio.serial.gen.DelegateGenerator; import one.nio.serial.gen.StubGenerator; @@ -55,7 +54,7 @@ public class GeneratedSerializer extends Serializer { this.defaultFields = new FieldDescriptor[0]; checkFieldTypes(); - this.delegate = BytecodeGenerator.INSTANCE.instantiate(code(), Delegate.class); + this.delegate = DelegateGenerator.instantiate(cls, fds, code()); } @Override @@ -93,7 +92,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept this.defaultFields = assignDefaultFields(ownFields); checkFieldTypes(); - this.delegate = BytecodeGenerator.INSTANCE.instantiate(code(), Delegate.class); + this.delegate = DelegateGenerator.instantiate(cls, fds, code()); } @Override diff --git a/src/one/nio/serial/Repository.java b/src/one/nio/serial/Repository.java index d5248be..5e1f946 100755 --- a/src/one/nio/serial/Repository.java +++ b/src/one/nio/serial/Repository.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.io.Serializable; import java.lang.reflect.Method; +import java.math.BigDecimal; import java.math.BigInteger; import java.net.*; import java.nio.file.Files; @@ -62,6 +63,7 @@ public class Repository { public static final int INLINE = 4; public static final int FIELD_SERIALIZATION = 8; public static final int SYNTHETIC_FIELDS = 16; + public static final int PROVIDE_GET_FIELD = 32; public static final int ARRAY_STUBS = 1; public static final int COLLECTION_STUBS = 2; @@ -150,6 +152,7 @@ public class Repository { setOptions(StringBuilder.class, SKIP_CUSTOM_SERIALIZATION); setOptions(StringBuffer.class, SKIP_CUSTOM_SERIALIZATION); setOptions(BigInteger.class, SKIP_CUSTOM_SERIALIZATION); + setOptions(BigDecimal.class, PROVIDE_GET_FIELD); // At some moment InetAddress fields were moved to an auxilary holder class. // This resolves backward compatibility problem by inlining holder fields during serialization. diff --git a/src/one/nio/serial/gen/DelegateGenerator.java b/src/one/nio/serial/gen/DelegateGenerator.java index ce48de3..7881492 100755 --- a/src/one/nio/serial/gen/DelegateGenerator.java +++ b/src/one/nio/serial/gen/DelegateGenerator.java @@ -46,6 +46,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -90,6 +91,31 @@ private static void defineBootstrapClass(Method m, String classData) throws Refl m.invoke(null, null, null, code, 0, code.length, null, null); } + public static Delegate instantiate(Class cls, FieldDescriptor[] fds, byte[] code) { + Map fieldsMap = null; + if (Repository.hasOptions(cls, Repository.PROVIDE_GET_FIELD)) { + fieldsMap = new HashMap<>(fds.length, 1); + for (FieldDescriptor fd : fds) { + Field field = fd.ownField(); + if (field != null) { + fieldsMap.put(field.getName(), field); + JavaInternals.setAccessible(field); + } + } + } + try { + return (Delegate) BytecodeGenerator.INSTANCE.defineClass(code) + .getDeclaredConstructor(Map.class) + .newInstance(fieldsMap); + } catch (Exception e) { + throw new IllegalArgumentException("Cannot instantiate class", e); + } + } + + public static Delegate instantiate(Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields) { + return instantiate(cls, fds, generate(cls, fds, defaultFields)); + } + public static byte[] generate(Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields) { String className = "sun/reflect/Delegate" + index.getAndIncrement() + '_' + cls.getSimpleName(); @@ -97,25 +123,31 @@ public static byte[] generate(Class cls, FieldDescriptor[] fds, FieldDescriptor[ cv.visit(V1_6, ACC_PUBLIC | ACC_FINAL, className, null, MAGIC_CLASS, new String[]{"one/nio/serial/gen/Delegate"}); - generateConstructor(cv); + generateConstructor(cv, className); generateCalcSize(cv, cls, fds); generateWrite(cv, cls, fds); - generateRead(cv, cls, fds, defaultFields); + generateRead(cv, cls, fds, defaultFields, className); generateSkip(cv, fds); generateToJson(cv, cls, fds); - generateFromJson(cv, cls, fds, defaultFields); + generateFromJson(cv, cls, fds, defaultFields, className); cv.visitEnd(); return cv.toByteArray(); } - private static void generateConstructor(ClassVisitor cv) { - MethodVisitor mv = cv.visitMethod(ACC_PUBLIC, "", "()V", null, null); + private static void generateConstructor(ClassVisitor cv, String className) { + MethodVisitor mv = cv.visitMethod(ACC_PUBLIC, "", "(Ljava/util/Map;)V", null, null); + cv.visitField(ACC_PRIVATE | ACC_FINAL, "fields", "Ljava/util/Map;", null, null).visitEnd(); + mv.visitCode(); mv.visitVarInsn(ALOAD, 0); mv.visitMethodInsn(INVOKESPECIAL, MAGIC_CLASS, "", "()V", false); + mv.visitVarInsn(ALOAD, 0); + mv.visitVarInsn(ALOAD, 1); + mv.visitFieldInsn(PUTFIELD, className, "fields", "Ljava/util/Map;"); + mv.visitInsn(RETURN); mv.visitMaxs(0, 0); mv.visitEnd(); @@ -212,7 +244,7 @@ private static void emitWriteObject(Class cls, MethodVisitor mv) { } } - private static void generateRead(ClassVisitor cv, Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields) { + private static void generateRead(ClassVisitor cv, Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields, String className) { MethodVisitor mv = cv.visitMethod(ACC_PUBLIC | ACC_FINAL, "read", "(Lone/nio/serial/DataStream;)Ljava/lang/Object;", null, new String[]{"java/io/IOException", "java/lang/ClassNotFoundException"}); mv.visitCode(); @@ -261,20 +293,30 @@ private static void generateRead(ClassVisitor cv, Class cls, FieldDescriptor[] f if (isRecord) { generateCreateRecord(mv, cls, fds, defaultFields); } - - emitReadObject(cls, mv); + + emitReadObject(cls, mv, className); mv.visitInsn(ARETURN); mv.visitMaxs(0, 0); mv.visitEnd(); } - private static void emitReadObject(Class cls, MethodVisitor mv) { + private static void emitReadObject(Class cls, MethodVisitor mv, String className) { MethodType methodType = MethodType.methodType(void.class, ObjectInputStream.class); MethodHandleInfo m = MethodHandlesReflection.findInstanceMethod(cls, "readObject", methodType); if (m != null && !Repository.hasOptions(m.getDeclaringClass(), Repository.SKIP_READ_OBJECT)) { mv.visitInsn(DUP); - mv.visitFieldInsn(GETSTATIC, "one/nio/serial/gen/NullObjectInputStream", "INSTANCE", "Lone/nio/serial/gen/NullObjectInputStream;"); + if (!Repository.hasOptions(m.getDeclaringClass(), Repository.PROVIDE_GET_FIELD)) { + mv.visitFieldInsn(GETSTATIC, "one/nio/serial/gen/NullObjectInputStream", "INSTANCE", "Lone/nio/serial/gen/NullObjectInputStream;"); + } else { + mv.visitInsn(DUP); + mv.visitTypeInsn(NEW, "one/nio/serial/gen/GetFieldInputStream"); + mv.visitInsn(DUP_X1); + mv.visitInsn(SWAP); + mv.visitVarInsn(ALOAD, 0); + mv.visitFieldInsn(GETFIELD, className, "fields", "Ljava/util/Map;"); + mv.visitMethodInsn(INVOKESPECIAL, "one/nio/serial/gen/GetFieldInputStream", "", "(Ljava/lang/Object;Ljava/util/Map;)V", false); + } emitInvoke(mv, m); } } @@ -377,7 +419,7 @@ private static void generateToJson(ClassVisitor cv, Class cls, FieldDescriptor[] mv.visitEnd(); } - private static void generateFromJson(ClassVisitor cv, Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields) { + private static void generateFromJson(ClassVisitor cv, Class cls, FieldDescriptor[] fds, FieldDescriptor[] defaultFields, String className) { MethodVisitor mv = cv.visitMethod(ACC_PUBLIC | ACC_FINAL, "fromJson", "(Lone/nio/serial/JsonReader;)Ljava/lang/Object;", null, new String[]{"java/io/IOException", "java/lang/ClassNotFoundException"}); mv.visitCode(); @@ -507,7 +549,7 @@ private static void generateFromJson(ClassVisitor cv, Class cls, FieldDescriptor generateCreateRecord(mv, cls, fds, defaultFields); } - emitReadObject(cls, mv); + emitReadObject(cls, mv, className); mv.visitInsn(ARETURN); mv.visitMaxs(0, 0); diff --git a/src/one/nio/serial/gen/GetFieldInputStream.java b/src/one/nio/serial/gen/GetFieldInputStream.java new file mode 100644 index 0000000..689985f --- /dev/null +++ b/src/one/nio/serial/gen/GetFieldInputStream.java @@ -0,0 +1,149 @@ +/* + * Copyright 2022 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.serial.gen; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.lang.reflect.Field; +import java.util.Map; + +class GetFieldInputStream extends NullObjectInputStream { + + private final Map fields; + private final Object source; + + GetFieldInputStream(Object source, Map fields) throws IOException, SecurityException { + this.fields = fields; + this.source = source; + } + + @Override + public GetField readFields() { + return new ObjectGetField(fields, source); + } + + private static class ObjectGetField extends ObjectInputStream.GetField { + private final Object object; + private final Map fields; + + private ObjectGetField(Map fields, Object object) { + this.object = object; + this.fields = fields; + } + + @Override + public ObjectStreamClass getObjectStreamClass() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean defaulted(String name) { + return !fields.containsKey(name); + } + + @Override + public boolean get(String name, boolean val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getBoolean(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public byte get(String name, byte val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getByte(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public char get(String name, char val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getChar(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public short get(String name, short val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getShort(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public int get(String name, int val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getInt(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public long get(String name, long val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getLong(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public float get(String name, float val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getFloat(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public double get(String name, double val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.getDouble(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + + @Override + public Object get(String name, Object val) throws IOException { + try { + Field field = fields.get(name); + return field != null ? field.get(object) : val; + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + } +} diff --git a/src/one/nio/serial/gen/NullObjectInputStream.java b/src/one/nio/serial/gen/NullObjectInputStream.java index c6d4377..d376b36 100755 --- a/src/one/nio/serial/gen/NullObjectInputStream.java +++ b/src/one/nio/serial/gen/NullObjectInputStream.java @@ -31,8 +31,7 @@ public class NullObjectInputStream extends ObjectInputStream { } } - private NullObjectInputStream() throws IOException { - // Singleton + protected NullObjectInputStream() throws IOException { } @Override diff --git a/test/one/nio/serial/ConversionTest.java b/test/one/nio/serial/ConversionTest.java index 2c99e22..a42c13a 100644 --- a/test/one/nio/serial/ConversionTest.java +++ b/test/one/nio/serial/ConversionTest.java @@ -16,7 +16,6 @@ package one.nio.serial; -import one.nio.gen.BytecodeGenerator; import one.nio.serial.gen.Delegate; import one.nio.serial.gen.DelegateGenerator; import org.junit.Test; @@ -32,13 +31,11 @@ public class ConversionTest implements Serializable { @Test public void testFieldConversion() throws Exception { - byte[] code = DelegateGenerator.generate(ConversionTest.class, new FieldDescriptor[]{ + Delegate delegate = DelegateGenerator.instantiate(ConversionTest.class, new FieldDescriptor[]{ fd("intField", BigInteger.class), fd("longField", BigInteger.class) }, new FieldDescriptor[0]); - Delegate delegate = BytecodeGenerator.INSTANCE.instantiate(code, Delegate.class); - byte[] data = new byte[100]; delegate.write(new ConversionTest(), new DataStream(data)); ConversionTest clone = (ConversionTest) delegate.read(new DataStream(data)); diff --git a/test/one/nio/serial/DefaultFieldsTest.java b/test/one/nio/serial/DefaultFieldsTest.java index b113414..057bf0a 100755 --- a/test/one/nio/serial/DefaultFieldsTest.java +++ b/test/one/nio/serial/DefaultFieldsTest.java @@ -16,7 +16,6 @@ package one.nio.serial; -import one.nio.gen.BytecodeGenerator; import one.nio.serial.gen.Delegate; import one.nio.serial.gen.DelegateGenerator; import org.junit.Test; @@ -24,7 +23,6 @@ import java.io.Serializable; import java.lang.annotation.ElementType; import java.lang.reflect.Field; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Set; @@ -63,9 +61,8 @@ public void testDefaultFields() throws Exception { for (int i = 0; i < ownFields.length; i++) { defaultFields[i] = new FieldDescriptor(ownFields[i], null, i); } - byte[] code = DelegateGenerator.generate(DefaultFieldsTest.class, new FieldDescriptor[0], defaultFields); - Delegate delegate = BytecodeGenerator.INSTANCE.instantiate(code, Delegate.class); + Delegate delegate = DelegateGenerator.instantiate(DefaultFieldsTest.class, new FieldDescriptor[0], defaultFields); DefaultFieldsTest obj = (DefaultFieldsTest) delegate.read(new DataStream(0)); assertEquals("abc", obj.s); diff --git a/test/one/nio/serial/SerializationTest.java b/test/one/nio/serial/SerializationTest.java index 7515007..230fd9f 100755 --- a/test/one/nio/serial/SerializationTest.java +++ b/test/one/nio/serial/SerializationTest.java @@ -286,6 +286,26 @@ public void testExceptions() throws IOException, ClassNotFoundException { assertEquals(css1.count(), css2.count()); } + @Test + public void testBigInteger() throws IOException, ClassNotFoundException { + checkSerialize(new BigInteger("12345678901234567890")); + checkSerialize(new BigInteger(-1, new byte[]{11, 22, 33, 44, 55, 66, 77, 88, 99})); + } + + @Test + public void testBigDecimal() throws IOException, ClassNotFoundException { + checkSerialize(new BigDecimal("999.999999999")); + checkSerialize(new BigDecimal(999.999999999)); + checkSerialize(new BigDecimal("88888888888888888.88888888888888888888888")); + checkSerialize(new BigDecimal("88888888888888888.88888888888888888888888")); + checkSerialize(new BigDecimal("12.3E+7")); + checkSerialize(Arrays.asList( + new BigDecimal("88888888888888888.88888888888888888888888"), + new BigDecimal("1"), + new BigDecimal(1), + new BigDecimal(0))); + } + @Test public void testInetAddress() throws IOException, ClassNotFoundException { checkSerialize(InetAddress.getByName("123.45.67.89")); @@ -298,14 +318,6 @@ public void testInetAddress() throws IOException, ClassNotFoundException { checkSerialize(new InetSocketAddress("google.com", 443)); } - @Test - public void testBigDecimal() throws IOException, ClassNotFoundException { - checkSerialize(new BigInteger("12345678901234567890")); - checkSerialize(new BigInteger(-1, new byte[]{11, 22, 33, 44, 55, 66, 77, 88, 99})); - checkSerialize(new BigDecimal(999.999999999)); - checkSerialize(new BigDecimal("88888888888888888.88888888888888888888888")); - } - @Test public void testStringBuilder() throws IOException, ClassNotFoundException { checkSerializeToString(new StringBuilder());