From c27eea8584b19cfd1a99b8561a625341335910b0 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 18 Mar 2024 01:36:32 +0200 Subject: [PATCH 01/11] Refactor TornadoHelper in unit tests for optimization. --- .../unittests/tools/TornadoHelper.java | 46 ++++++++----------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tools/TornadoHelper.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tools/TornadoHelper.java index 8566d6dff3..b476e22ae2 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tools/TornadoHelper.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tools/TornadoHelper.java @@ -18,7 +18,6 @@ package uk.ac.manchester.tornado.unittests.tools; -import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; import static org.junit.platform.engine.discovery.DiscoverySelectors.selectMethod; import java.io.BufferedWriter; @@ -29,6 +28,7 @@ import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; +import java.util.Arrays; import java.util.Date; import java.util.HashSet; @@ -55,9 +55,7 @@ public class TornadoHelper { public static final boolean OPTIMIZE_LOAD_STORE_SPIRV = Boolean.parseBoolean(System.getProperty("tornado.spirv.loadstore", "False")); - // private static void printResult(Result result) { - // System.out.printf("Test ran: %s, Failed: %s%n", result.getRunCount(), result.getFailureCount()); - // } + public static final String SEPERATOR = " ................ "; private static void printResult(SummaryGeneratingListener result) { System.out.printf("Test ran: %s, Failed: %s%n", result.getSummary().getTestsStartedCount(), result.getSummary().getTestsFailedCount()); @@ -152,18 +150,14 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF bufferFile.append(message); if (suite != null && suite.unsupportedMethods.contains(m)) { - message = String.format("%20s", " ................ " + ColorsTerminal.YELLOW + " [NOT VALID TEST: UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.YELLOW + " [NOT VALID TEST: UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; continue; } - LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request().selectors(selectClass(klass)).build(); - // Request request = Request.method(klass, m.getName()); - // Result result = new JUnitCore().run(request); - - // Create a Launcher and register a listener if needed (e.g., to generate a summary) + LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request().selectors(selectMethod(klass, m)).build(); Launcher launcher = LauncherFactory.create(); SummaryGeneratingListener listener = new SummaryGeneratingListener(); launcher.registerTestExecutionListeners(listener); @@ -172,15 +166,14 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF launcher.execute(request); if (listener.getSummary().getTestsFailedCount() == 0) { - message = String.format("%20s", " ................ " + ColorsTerminal.GREEN + " [PASS] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.GREEN + " [PASS] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); successCounter++; } else { - // If UnsupportedConfigurationException is thrown this means that test did not - // fail, it simply can't be run on current configuration + // If UnsupportedConfigurationException is thrown this means that test did not fail, it simply can't be run on current configuration if (listener.getSummary().getFailures().stream().filter(e -> (e.getException() instanceof UnsupportedConfigurationException)).count() > 0) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [UNSUPPORTED CONFIGURATION: At least 2 accelerators are required] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [UNSUPPORTED CONFIGURATION: At least 2 accelerators are required] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -188,7 +181,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoVMPTXNotSupported))) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [PTX CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [PTX CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -196,7 +189,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoNoOpenCLPlatformException))) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [OPENCL CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [OPENCL CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -204,7 +197,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoVMMultiDeviceNotSupported))) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [[UNSUPPORTED] MULTI-DEVICE CONFIGURATION REQUIRED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [[UNSUPPORTED] MULTI-DEVICE CONFIGURATION REQUIRED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -212,7 +205,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoVMOpenCLNotSupported))) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [OPENCL CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [OPENCL CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -220,7 +213,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoVMSPIRVNotSupported))) { - message = String.format("%20s", " ................ " + ColorsTerminal.PURPLE + " [SPIRV CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.PURPLE + " [SPIRV CONFIGURATION UNSUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; @@ -228,7 +221,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof SPIRVOptNotSupported)) && OPTIMIZE_LOAD_STORE_SPIRV) { - message = String.format("%20s", " ................ " + ColorsTerminal.RED + " [SPIRV OPTIMIZATION NOT SUPPORTED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.RED + " [SPIRV OPTIMIZATION NOT SUPPORTED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); failedCounter++; @@ -236,20 +229,21 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } if (listener.getSummary().getFailures().stream().anyMatch(e -> (e.getException() instanceof TornadoDeviceFP64NotSupported))) { - message = String.format("%20s", " ................ " + ColorsTerminal.YELLOW + " [FP64 UNSUPPORTED FOR CURRENT DEVICE] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.YELLOW + " [FP64 UNSUPPORTED FOR CURRENT DEVICE] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); notSupported++; continue; } - message = String.format("%20s", " ................ " + ColorsTerminal.RED + " [FAILED] " + ColorsTerminal.RESET + "\n"); + message = String.format("%20s", SEPERATOR + ColorsTerminal.RED + " [FAILED] " + ColorsTerminal.RESET + "\n"); bufferConsole.append(message); bufferFile.append(message); failedCounter++; for (TestExecutionSummary.Failure failure : listener.getSummary().getFailures()) { - bufferConsole.append("\t\t\\_[REASON] " + failure.toString() + "\n"); - bufferFile.append("\t\t\\_[REASON] " + failure.toString() + "\n\t" + failure.getException().getStackTrace().toString() + "\n" + failure.toString() + "\n" + failure.getException()); + bufferConsole.append("\t\t\\_[REASON] " + failure.getException().getMessage() + "\n"); + bufferFile.append("\t\t\\_[REASON] " + failure.getException().getMessage() + "\n\t" + Arrays.toString(failure.getException().getStackTrace()) + "\n" + failure + .getException() + "\n"); } } } @@ -269,7 +263,7 @@ static void runTestVerbose(String klassName, String methodName) throws ClassNotF } } - static void runTestClassAndMethod(String klassName, String methodName) throws ClassNotFoundException { + static void runTestClassAndMethod(String klassName, String methodName) { LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request().selectors(selectMethod(klassName, methodName)).build(); Launcher launcher = LauncherFactory.create(); SummaryGeneratingListener listener = new SummaryGeneratingListener(); @@ -278,7 +272,7 @@ static void runTestClassAndMethod(String klassName, String methodName) throws Cl printResult(listener); } - static void runTestClass(String klassName) throws ClassNotFoundException { + static void runTestClass(String klassName) { LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request().selectors(selectMethod(klassName)).build(); Launcher launcher = LauncherFactory.create(); SummaryGeneratingListener listener = new SummaryGeneratingListener(); From 18b446848888a5bbb5d4719cfc899f6ce7c3a24e Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 18 Mar 2024 01:40:24 +0200 Subject: [PATCH 02/11] Remove junit-vintage-engine dependency from pom.xml --- tornado-unittests/pom.xml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tornado-unittests/pom.xml b/tornado-unittests/pom.xml index 600eb5f240..4e9529c224 100644 --- a/tornado-unittests/pom.xml +++ b/tornado-unittests/pom.xml @@ -26,11 +26,6 @@ junit-jupiter-engine ${junit.jupiter.version} - - - - - org.junit.platform junit-platform-launcher From 343e8d998b2b11d37236ee5185ef5320760b0f4f Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 18 Mar 2024 11:12:42 +0200 Subject: [PATCH 03/11] Refactor TestBatches for Junit5 --- .../unittests/batches/TestBatches.java | 245 ++++++++++++------ 1 file changed, 169 insertions(+), 76 deletions(-) diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java index 2ce753f771..d84d6c1e0a 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java @@ -243,24 +243,33 @@ public void test100MB() { @Test public void test300MB() { - long maxAllocMemory=checkMaxHeapAllocationOnDevice(300,MemoryUnit.MB); + long maxAllocMemory = checkMaxHeapAllocationOnDevice(300, MemoryUnit.MB); // Fill 1.0GB - int size=250_000_000; + int size = 250_000_000; // Or as much as we can - if(size*4>maxAllocMemory){size=(int)((maxAllocMemory/4/2)*0.9);}FloatArray arrayA=new FloatArray(size);FloatArray arrayB=new FloatArray(size); + if (size * 4 > maxAllocMemory) { + size = (int) ((maxAllocMemory / 4 / 2) * 0.9); + } + FloatArray arrayA = new FloatArray(size); + FloatArray arrayB = new FloatArray(size); - Random r=new Random();IntStream.range(0,arrayA.getSize()).sequential().forEach(idx->arrayA.set(idx,r.nextFloat())); + Random r = new Random(); + IntStream.range(0, arrayA.getSize()).sequential().forEach(idx -> arrayA.set(idx, r.nextFloat())); - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,arrayA) // - .task("t0",TestBatches::compute,arrayA,arrayB) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,arrayB); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, arrayA) // + .task("t0", TestBatches::compute, arrayA, arrayB) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, arrayB); - ImmutableTaskGraph immutableTaskGraph=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(immutableTaskGraph);executionPlan.withBatch("300MB") // Slots of 300 MB - .execute(); + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.withBatch("300MB") // Slots of 300 MB + .execute(); - for(int i=0;iexecutionPlan.withBatch("1MB").execute());executionPlan.freeDeviceMemory(); + assertThrows(TornadoBailoutRuntimeException.class, () -> { + checkMaxHeapAllocationOnDevice(5, MemoryUnit.MB); + IntArray a0 = new IntArray(2 * 1_000_000); + IntArray a1 = new IntArray(3 * 1_000_000); + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); + executionPlan.freeDeviceMemory(); + }); } @Test public void testSameInputSizeAndTypeRestrictionJavaArrays() { // total input size mismatch for int[] - checkMaxHeapAllocationOnDevice(5,MemoryUnit.MB);int[]a0=new int[2*1_000_000];int[]a1=new int[3*1_000_000]; - - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);assertThrows(TornadoBailoutRuntimeException.class,()->executionPlan.withBatch("1MB").execute());executionPlan.freeDeviceMemory(); + assertThrows(TornadoBailoutRuntimeException.class, () -> { + checkMaxHeapAllocationOnDevice(5, MemoryUnit.MB); + int[] a0 = new int[2 * 1_000_000]; + int[] a1 = new int[3 * 1_000_000]; + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); + executionPlan.freeDeviceMemory(); + }); } @Test public void testSameInputTypeRestriction() { // IntArray is NOT compatible with LongArray even if the total input size is equal - checkMaxHeapAllocationOnDevice(6,MemoryUnit.MB);IntArray a0=new IntArray(4*1_000_000);LongArray a1=new LongArray(2*1_000_000); - - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot); - - assertThrows(TornadoBailoutRuntimeException.class,()->executionPlan.withBatch("1MB").execute());executionPlan.freeDeviceMemory(); + assertThrows(TornadoBailoutRuntimeException.class, () -> { + checkMaxHeapAllocationOnDevice(6, MemoryUnit.MB); + IntArray a0 = new IntArray(4 * 1_000_000); + LongArray a1 = new LongArray(2 * 1_000_000); + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); + executionPlan.freeDeviceMemory(); + }); } @Test public void testSameInputTypeRestrictionJavaArrays() { // int[] is NOT compatible with long[] even if the total input size is equal - checkMaxHeapAllocationOnDevice(6,MemoryUnit.MB);int[]a0=new int[4*1_000_000];long[]a1=new long[2*1_000_000]; - - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);assertThrows(TornadoBailoutRuntimeException.class,()->executionPlan.withBatch("1MB").execute());executionPlan.freeDeviceMemory(); + assertThrows(TornadoBailoutRuntimeException.class, () -> { + + checkMaxHeapAllocationOnDevice(6, MemoryUnit.MB); + int[] a0 = new int[4 * 1_000_000]; + long[] a1 = new long[2 * 1_000_000]; + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); + executionPlan.freeDeviceMemory(); + }); } @Test public void testSameInputSize() { // IntArray is compatible with FloatArray for the same # of elements - checkMaxHeapAllocationOnDevice(4,MemoryUnit.MB);IntArray a0=new IntArray(2*1_000_000);IntStream.range(0,a0.getSize()).forEach(i->a0.set(i,i));FloatArray a1=new FloatArray(2*1_000_000); + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + IntArray a0 = new IntArray(2 * 1_000_000); + IntStream.range(0, a0.getSize()).forEach(i -> a0.set(i, i)); + FloatArray a1 = new FloatArray(2 * 1_000_000); - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;ia0[i]=i);float[]a1=new float[2*1_000_000]; + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + int[] a0 = new int[2 * 1_000_000]; + IntStream.range(0, a0.length).forEach(i -> a0[i] = i); + float[] a1 = new float[2 * 1_000_000]; - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;ia0[i]=i);FloatArray a1=new FloatArray(2*1_000_000); + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + int[] a0 = new int[2 * 1_000_000]; + IntStream.range(0, a0.length).forEach(i -> a0[i] = i); + FloatArray a1 = new FloatArray(2 * 1_000_000); - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;ia0.set(i,i));float[]a1=new float[2*1_000_000]; + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + IntArray a0 = new IntArray(2 * 1_000_000); + IntStream.range(0, a0.getSize()).forEach(i -> a0.set(i, i)); + float[] a1 = new float[2 * 1_000_000]; - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;ia0[i]=i);IntArray a1=new IntArray(2*1_000_000); + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + int[] a0 = new int[2 * 1_000_000]; + IntStream.range(0, a0.length).forEach(i -> a0[i] = i); + IntArray a1 = new IntArray(2 * 1_000_000); - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;ia0.set(i,i));int[]a1=new int[2*1_000_000]; + checkMaxHeapAllocationOnDevice(4, MemoryUnit.MB); + IntArray a0 = new IntArray(2 * 1_000_000); + IntStream.range(0, a0.getSize()).forEach(i -> a0.set(i, i)); + int[] a1 = new int[2 * 1_000_000]; - TaskGraph taskGraph=new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION,a0) // - .task("t0",TestBatches::compute,a0,a1) // - .transferToHost(DataTransferMode.EVERY_EXECUTION,a1);ImmutableTaskGraph snapshot=taskGraph.snapshot();TornadoExecutionPlan executionPlan=new TornadoExecutionPlan(snapshot);executionPlan.withBatch("1MB").execute(); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a0) // + .task("t0", TestBatches::compute, a0, a1) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, a1); + ImmutableTaskGraph snapshot = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(snapshot); + executionPlan.withBatch("1MB").execute(); - for(int i=0;i Date: Thu, 21 Mar 2024 15:44:41 +0200 Subject: [PATCH 04/11] Fix merge conflicts --- .../tornado/unittests/api/TestDevices.java | 21 ++++++++++++------- .../unittests/vector/api/TestVectorAPI.java | 17 ++++++++------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestDevices.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestDevices.java index 3ab94433cb..17eaa86869 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestDevices.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestDevices.java @@ -17,13 +17,14 @@ */ package uk.ac.manchester.tornado.unittests.api; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.TornadoBackend; import uk.ac.manchester.tornado.api.TornadoDeviceMap; @@ -49,18 +50,22 @@ public class TestDevices extends TornadoTestBase { * We ask, on purpose, for a backend index that does not exist to * check that the exception {@link TornadoBackendNotFound} in thrown. */ - @Test(expected = TornadoBackendNotFound.class) + @Test public void test01() { - TornadoDevice device = TornadoExecutionPlan.getDevice(100, 0); + assertThrows(TornadoBackendNotFound.class, () -> { + TornadoDevice device = TornadoExecutionPlan.getDevice(100, 0); + }); } /** * We ask, on purpose, for a device index that does not exist to * check that the exception {@link TornadoDeviceNotFound} in thrown. */ - @Test(expected = TornadoDeviceNotFound.class) + @Test public void test02() { - TornadoDevice device = TornadoExecutionPlan.getDevice(0, 100); + assertThrows(TornadoDeviceNotFound.class, () -> { + TornadoDevice device = TornadoExecutionPlan.getDevice(0, 100); + }); } /** diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vector/api/TestVectorAPI.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vector/api/TestVectorAPI.java index 4ffd1018bd..120240ab64 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vector/api/TestVectorAPI.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vector/api/TestVectorAPI.java @@ -17,13 +17,14 @@ */ package uk.ac.manchester.tornado.unittests.vector.api; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + import java.nio.ByteOrder; import java.util.Random; import java.util.stream.IntStream; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorSpecies; @@ -52,7 +53,7 @@ public static float randFloat(float min, float max, Random rand) { return rand.nextFloat() * (max - min) + min; } - @BeforeClass + @BeforeAll public static void setUpBeforeClass() { arrayA = new FloatArray(SIZE); arrayB = new FloatArray(SIZE); @@ -108,7 +109,7 @@ private float[] parallelVectorAdd(FloatArray vector1, FloatArray vector2, Vector public void test64BitVectors() { VectorSpecies species = FloatVector.SPECIES_64; float[] result = parallelVectorAdd(arrayA, arrayB, species); - Assert.assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); + assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); } /** @@ -118,7 +119,7 @@ public void test64BitVectors() { public void test128BitVectors() { VectorSpecies species = FloatVector.SPECIES_128; float[] result = parallelVectorAdd(arrayA, arrayB, species); - Assert.assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); + assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); } /** @@ -128,7 +129,7 @@ public void test128BitVectors() { public void test256BitVectors() { VectorSpecies species = FloatVector.SPECIES_256; float[] result = parallelVectorAdd(arrayA, arrayB, species); - Assert.assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); + assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); } /** @@ -138,6 +139,6 @@ public void test256BitVectors() { public void test512BitVectors() { VectorSpecies species = FloatVector.SPECIES_512; float[] result = parallelVectorAdd(arrayA, arrayB, species); - Assert.assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); + assertArrayEquals(result, referenceResult.toHeapArray(), DELTA); } } From 2aadad478f76d008af77cafc4289a5be2d7a3c97 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 16 Apr 2024 14:05:16 +0300 Subject: [PATCH 05/11] Update unit tests to use JUnit 5 --- .../tornado/unittests/api/TestConcat.java | 52 +++++++++---------- .../unittests/api/TestInitDataTypes.java | 4 +- .../tornado/unittests/api/TestSlice.java | 22 ++++---- .../tensors/TestTensorAPIWithOnnx.java | 7 +-- .../unittests/tensors/TestTensorTypes.java | 28 +++++----- .../unittests/vectortypes/TestHalfFloats.java | 9 ++-- 6 files changed, 65 insertions(+), 57 deletions(-) diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestConcat.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestConcat.java index f663d1a37e..4c3aea8f29 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestConcat.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestConcat.java @@ -17,7 +17,7 @@ */ package uk.ac.manchester.tornado.unittests.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.CharArray; @@ -29,7 +29,7 @@ import uk.ac.manchester.tornado.api.types.arrays.ShortArray; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * How to run? @@ -55,14 +55,14 @@ public void testFloatArrayConcat() { FloatArray c = FloatArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -80,14 +80,14 @@ public void testDoubleArrayConcat() { DoubleArray c = DoubleArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -105,14 +105,14 @@ public void testByteArrayConcat() { ByteArray c = ByteArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -130,14 +130,14 @@ public void testLongArrayConcat() { LongArray c = LongArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -155,14 +155,14 @@ public void testIntArrayConcat() { IntArray c = IntArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -180,14 +180,14 @@ public void testShortArrayConcat() { ShortArray c = ShortArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -205,14 +205,14 @@ public void testCharArrayConcat() { CharArray c = CharArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i), 0.0f); + assertEquals(100.0f, c.get(i), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i), 0.0f,"Mismatch in third part of c"); } } @@ -230,14 +230,14 @@ public void testHalfFloatArrayConcat() { HalfFloatArray c = HalfFloatArray.concat(a, b, e); for (int i = 0; i < a.getSize(); i++) { - assertEquals("Mismatch in first part of c", 100.0f, c.get(i).getFloat32(), 0.0f); + assertEquals(100.0f, c.get(i).getFloat32(), 0.0f,"Mismatch in first part of c"); } for (int i = 0; i < b.getSize(); i++) { - assertEquals("Mismatch in second part of c", 5.0f, c.get(a.getSize() + i).getFloat32(), 0.0f); + assertEquals( 5.0f, c.get(a.getSize() + i).getFloat32(), 0.0f, "Mismatch in second part of c"); } for (int i = 0; i < e.getSize(); i++) { - assertEquals("Mismatch in third part of c", 12f, c.get(a.getSize() + b.getSize() + i).getFloat32(), 0.0f); + assertEquals( 12f, c.get(a.getSize() + b.getSize() + i).getFloat32(), 0.0f,"Mismatch in third part of c"); } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestInitDataTypes.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestInitDataTypes.java index 7037a3be7d..c22924ee1d 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestInitDataTypes.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestInitDataTypes.java @@ -17,7 +17,7 @@ */ package uk.ac.manchester.tornado.unittests.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -33,7 +33,7 @@ import uk.ac.manchester.tornado.api.types.arrays.ShortArray; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * How to run? diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestSlice.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestSlice.java index 071a653556..7af2c98cf7 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestSlice.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestSlice.java @@ -17,7 +17,7 @@ */ package uk.ac.manchester.tornado.unittests.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.CharArray; @@ -29,7 +29,7 @@ import uk.ac.manchester.tornado.api.types.arrays.ShortArray; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * How to run? @@ -59,11 +59,13 @@ public void testFloatArraySlice() { FloatArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0f, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } + + @Test public void testDoubleArraySlice() { @@ -80,7 +82,7 @@ public void testDoubleArraySlice() { DoubleArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0d, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -100,7 +102,7 @@ public void testByteArraySlice() { ByteArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -119,7 +121,7 @@ public void testLongArraySlice() { LongArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -139,7 +141,7 @@ public void testIntArraySlice() { IntArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0d, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -159,7 +161,7 @@ public void testShortArraySlice() { ShortArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -178,7 +180,7 @@ public void testCharArraySlice() { CharArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0d, slice.get(i), 0.0f); + assertEquals(5.0f, slice.get(i), 0.0f, "Mismatch in second part of slice"); } } @@ -198,7 +200,7 @@ public void testHalfFloatArraySlice() { HalfFloatArray slice = c.slice(256, numElements); for (int i = 0; i < slice.getSize(); i++) { - assertEquals("Mismatch in second part of slice", 5.0d, slice.get(i).getFloat32(), 0.0f); + assertEquals(5.0f, slice.get(i).getFloat32(), 0.0f, "Mismatch in second part of slice"); } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorAPIWithOnnx.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorAPIWithOnnx.java index edc24e779e..5543af82f8 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorAPIWithOnnx.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorAPIWithOnnx.java @@ -23,8 +23,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.types.tensors.Shape; import uk.ac.manchester.tornado.api.types.tensors.TensorFP32; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; @@ -40,6 +39,8 @@ import java.util.Map; import java.util.Optional; +import static org.junit.jupiter.api.Assertions.assertNotNull; + public class TestTensorAPIWithOnnx extends TornadoTestBase { private final String INPUT_TENSOR_NAME = "data"; @@ -97,7 +98,7 @@ public void testOnnxCompatibility() throws OrtException, IOException { } } finally { - Assert.assertNotNull(outputTensor); + assertNotNull(outputTensor); } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorTypes.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorTypes.java index d0ec4be377..a87656a540 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorTypes.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/tensors/TestTensorTypes.java @@ -17,9 +17,8 @@ */ package uk.ac.manchester.tornado.unittests.tensors; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -36,6 +35,9 @@ import uk.ac.manchester.tornado.api.types.tensors.TensorInt64; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; +import static org.junit.jupiter.api.Assertions.assertEquals; + + public class TestTensorTypes extends TornadoTestBase { public static void tensorAdditionFloat16(TensorFP16 tensorA, TensorFP16 tensorB, TensorFP16 tensorC) { @@ -88,10 +90,10 @@ public void testHelloTensorAPI() { tensorA.init(new HalfFloat(1f)); - Assert.assertEquals("Expected shape does not match", "Shape{dimensions=[64, 64, 64]}", tensorA.getShape().toString()); - Assert.assertEquals("Expected data type does not match", "HALF_FLOAT", tensorA.getDTypeAsString()); - Assert.assertEquals("Expected TensorFlow shape string does not match", "[64,64,64]", tensorA.getShape().toTensorFlowShapeString()); - Assert.assertEquals("Expected ONNX shape string does not match", "{dim_0: 64, dim_1: 64, dim_2: 64}", tensorA.getShape().toONNXShapeString()); + assertEquals("Expected shape does not match", "Shape{dimensions=[64, 64, 64]}", tensorA.getShape().toString()); + assertEquals("Expected data type does not match", "HALF_FLOAT", tensorA.getDTypeAsString()); + assertEquals("Expected TensorFlow shape string does not match", "[64,64,64]", tensorA.getShape().toTensorFlowShapeString()); + assertEquals("Expected ONNX shape string does not match", "{dim_0: 64, dim_1: 64, dim_2: 64}", tensorA.getShape().toONNXShapeString()); } @Test @@ -124,7 +126,7 @@ public void testTensorFloat16Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i).getFloat32(), HalfFloat.add(tensorA.get(i), tensorB.get(i)).getFloat32(), 0.00f); + assertEquals(tensorC.get(i).getFloat32(), HalfFloat.add(tensorA.get(i), tensorB.get(i)).getFloat32(), 0.00f); } } @@ -159,7 +161,7 @@ public void testTensorFloat32Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } @@ -194,7 +196,7 @@ public void testTensorFloat64Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } @@ -229,7 +231,7 @@ public void testTensorInt16Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } @@ -264,7 +266,7 @@ public void testTensorInt32Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } @@ -298,7 +300,7 @@ public void testTensorInt64Add() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } @@ -332,7 +334,7 @@ public void testTensorByte() { executionPlan.execute(); for (int i = 0; i < tensorC.getSize(); i++) { - Assert.assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); + assertEquals(tensorC.get(i), tensorA.get(i) + tensorB.get(i), 0.00f); } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vectortypes/TestHalfFloats.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vectortypes/TestHalfFloats.java index 7536b34240..961839736c 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vectortypes/TestHalfFloats.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vectortypes/TestHalfFloats.java @@ -17,7 +17,8 @@ */ package uk.ac.manchester.tornado.unittests.vectortypes; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -39,8 +40,9 @@ import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; import java.util.Random; +import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** *

@@ -905,7 +907,8 @@ public void testInternalSetMethod04() { } } - @Test(timeout = 1000) //timeout of 1sec + @Test + @Timeout(value = 1000, unit= TimeUnit.MILLISECONDS) //timeout of 1sec public void testAllocationIssue() { int size = 8192 * 4096; From c0d9a2c2806512462f26e17165e9825d6149c06e Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 23 Apr 2024 12:52:10 +0300 Subject: [PATCH 06/11] Partial fix for segfaults in OpenCL --- .../drivers/opencl/mm/OCLMemorySegmentWrapper.java | 2 +- .../TestMultiThreadedExecutionPlans.java | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java index c6ea33dc75..ae5384b9f7 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java @@ -166,7 +166,7 @@ public List enqueueWrite(long executionPlanId, Object reference, long b } else { internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), 0, TornadoNativeArray.ARRAY_HEADER, segment.address(), 0, (useDeps) ? events : null); returnEvents.add(internalEvent); - internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset + TornadoNativeArray.ARRAY_HEADER, bufferSize, segment.address(), + internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset + TornadoNativeArray.ARRAY_HEADER, batchSize, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, (useDeps) ? events : null); } returnEvents.add(internalEvent); diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java index 0293caf7cd..7ac36c081d 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java @@ -17,10 +17,10 @@ */ package uk.ac.manchester.tornado.unittests.multithreaded; -import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -172,13 +172,13 @@ public void test03() { try { t1.join(); } catch (InterruptedException e) { - assertTrue("Error", false); + assertTrue(false, "Error"); } try { t2.join(); } catch (InterruptedException e) { - assertTrue("Error", false); + assertTrue(false, "Error"); } } @@ -197,13 +197,13 @@ public void test04() { try { t1.join(); } catch (InterruptedException e) { - assertTrue("Error", false); + assertTrue(false, "Error"); } try { t2.join(); } catch (InterruptedException e) { - assertTrue("Error", false); + assertTrue(false, "Error"); } } From 85205591432709f26b55d3854ab394f6de7c7fbe Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 23 Apr 2024 18:18:49 +0300 Subject: [PATCH 07/11] Refactor memory allocation and test methods --- .../opencl/mm/OCLMemorySegmentWrapper.java | 4 +- .../unittests/batches/TestBatches.java | 40 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java index ae5384b9f7..858cc331ac 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java @@ -68,7 +68,7 @@ public OCLMemorySegmentWrapper(long bufferSize, OCLDeviceContext deviceContext, this.bufferSize = bufferSize; this.bufferId = INIT_VALUE; this.bufferOffset = 0; - onDevice = false; + onDevice = false; } @Override @@ -184,7 +184,7 @@ public void allocate(Object reference, long batchSize) throws TornadoOutOfMemory bufferId = deviceContext.getBufferProvider().getOrAllocateBufferWithSize(bufferSize); } else { bufferSize = batchSize; - bufferId = deviceContext.getBufferProvider().getOrAllocateBufferWithSize(bufferSize + TornadoNativeArray.ARRAY_HEADER); + bufferId = deviceContext.getBufferProvider().getOrAllocateBufferWithSize(batchSize + TornadoNativeArray.ARRAY_HEADER); } if (bufferSize <= 0) { diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java index 6015e9f0cb..4f3965a1cb 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java @@ -838,19 +838,19 @@ public void testSameInputSizeAndTypeTornadoToJava() { } @Test - public void testBatchNotEven() { + public void testBatchNotEven2() { checkMaxHeapAllocationOnDevice(64, MemoryUnit.MB); // Allocate ~ 64MB FloatArray array = new FloatArray(1024 * 1024 * 16); - FloatArray arraySeq = new FloatArray(1024 * 1024 * 16); - for (int i = 0; i < arraySeq.getSize(); i++) { - arraySeq.set(i, i); - } + FloatArray array2 = new FloatArray(1024 * 1024 * 16); + array.init(1.0f); + array2.init(1.0f); TaskGraph taskGraph = new TaskGraph("s0") // - .task("t0", TestBatches::parallelInitialization, array) // + .transferToDevice(DataTransferMode.EVERY_EXECUTION, array) // .task("t1", TestBatches::compute2, array) // + .task("t2", TestBatches::compute2, array) // .transferToHost(DataTransferMode.EVERY_EXECUTION, array); TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot()); @@ -858,13 +858,13 @@ public void testBatchNotEven() { .execute(); for (int i = 0; i < array.getSize(); i++) { - assertEquals(arraySeq.get(i) * 2, array.get(i), 0.01f); + assertEquals(array2.get(i) * 4, array.get(i), 0.01f); } executionPlan.freeDeviceMemory(); } @Test - public void testBatchNotEven2() { + public void testBatchNotEven2Lazy() { checkMaxHeapAllocationOnDevice(64, MemoryUnit.MB); // Allocate ~ 64MB @@ -880,43 +880,45 @@ public void testBatchNotEven2() { .transferToHost(DataTransferMode.EVERY_EXECUTION, array); TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot()); - executionPlan.withBatch("10MB") // Batches of 10MB + TornadoExecutionResult tornadoExecutionResult = executionPlan.withBatch("10MB") // Batches of 10MB .execute(); + tornadoExecutionResult.transferToHost(array); + for (int i = 0; i < array.getSize(); i++) { assertEquals(array2.get(i) * 4, array.get(i), 0.01f); } executionPlan.freeDeviceMemory(); } + @Test - public void testBatchNotEven2Lazy() { + public void testBatchNotEven() { checkMaxHeapAllocationOnDevice(64, MemoryUnit.MB); // Allocate ~ 64MB FloatArray array = new FloatArray(1024 * 1024 * 16); - FloatArray array2 = new FloatArray(1024 * 1024 * 16); - array.init(1.0f); - array2.init(1.0f); + FloatArray arraySeq = new FloatArray(1024 * 1024 * 16); + for (int i = 0; i < arraySeq.getSize(); i++) { + arraySeq.set(i, i); + } TaskGraph taskGraph = new TaskGraph("s0") // - .transferToDevice(DataTransferMode.EVERY_EXECUTION, array) // + .task("t0", TestBatches::parallelInitialization, array) // .task("t1", TestBatches::compute2, array) // - .task("t2", TestBatches::compute2, array) // .transferToHost(DataTransferMode.EVERY_EXECUTION, array); TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot()); - TornadoExecutionResult tornadoExecutionResult = executionPlan.withBatch("10MB") // Batches of 10MB + executionPlan.withBatch("10MB") // Batches of 10MB .execute(); - tornadoExecutionResult.transferToHost(array); - for (int i = 0; i < array.getSize(); i++) { - assertEquals(array2.get(i) * 4, array.get(i), 0.01f); + assertEquals(arraySeq.get(i) * 2, array.get(i), 0.01f); } executionPlan.freeDeviceMemory(); } + private long checkMaxHeapAllocationOnDevice(int size, MemoryUnit memoryUnit) throws UnsupportedConfigurationException { long maxAllocMemory = getTornadoRuntime().getDefaultDevice().getDeviceContext().getMemoryManager().getHeapSize(); From 3d25f40d8bcba858b397b6f76ebefd71fada3007 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Mon, 29 Apr 2024 13:19:50 +0300 Subject: [PATCH 08/11] Migrate unit tests to JUnit5 --- .../tornado/unittests/executor/TestExecutor.java | 10 +++++++--- .../unittests/memory/TestStressDeviceMemory.java | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/executor/TestExecutor.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/executor/TestExecutor.java index 0fa1221962..42a7412ad6 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/executor/TestExecutor.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/executor/TestExecutor.java @@ -17,13 +17,17 @@ */ package uk.ac.manchester.tornado.unittests.executor; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.Arrays; -import org.junit.Test; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/TestStressDeviceMemory.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/TestStressDeviceMemory.java index 5273487411..e7705204f1 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/TestStressDeviceMemory.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/TestStressDeviceMemory.java @@ -18,11 +18,8 @@ package uk.ac.manchester.tornado.unittests.memory; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -32,6 +29,9 @@ import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + /** * How to test? * From 582228e2f26c02ad880aeae34b0f015065f33c27 Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Mon, 20 May 2024 11:27:20 +0300 Subject: [PATCH 09/11] Fix PTX code cache using an invalid module --- .../drivers/opencl/mm/OCLMemorySegmentWrapper.java | 5 +++-- .../tornado/drivers/ptx/PTXCodeCache.java | 3 ++- .../drivers/ptx/mm/PTXMemorySegmentWrapper.java | 8 +++++--- tornado-unittests/pom.xml | 8 +------- tornado-unittests/src/main/java/module-info.java | 1 + .../tornado/unittests/batches/TestBatches.java | 14 ++++++++------ 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java index 858cc331ac..6992590c35 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java @@ -160,13 +160,14 @@ public List enqueueWrite(long executionPlanId, Object reference, long b MemorySegment segment; segment = getSegmentWithHeader(reference); + final long numBytes = getSizeSubRegionSize() > 0 ? getSizeSubRegionSize() : bufferSize; int internalEvent; if (batchSize <= 0) { - internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset, bufferSize, segment.address(), hostOffset, (useDeps) ? events : null); + internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset, numBytes, segment.address(), hostOffset, (useDeps) ? events : null); } else { internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), 0, TornadoNativeArray.ARRAY_HEADER, segment.address(), 0, (useDeps) ? events : null); returnEvents.add(internalEvent); - internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset + TornadoNativeArray.ARRAY_HEADER, batchSize, segment.address(), + internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset + TornadoNativeArray.ARRAY_HEADER, numBytes, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, (useDeps) ? events : null); } returnEvents.add(internalEvent); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java index 87c17af5d4..ee1a14a668 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java @@ -41,7 +41,8 @@ public PTXCodeCache(PTXDeviceContext deviceContext) { public PTXInstalledCode installSource(String name, byte[] targetCode, String resolvedMethodName, boolean debugKernel) { - if (!cache.containsKey(name)) { + PTXInstalledCode installedCode = cache.get(name); + if (installedCode == null || !installedCode.isValid()) { if (debugKernel) { RuntimeUtilities.dumpKernel(targetCode); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java index eb1e32224d..df621847fd 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java @@ -138,11 +138,12 @@ public void write(long executionPlanId, Object reference) { public int enqueueRead(long executionPlanId, Object reference, long hostOffset, int[] events, boolean useDeps) { MemorySegment segment = getSegmentWithHeader(reference); + final long numBytes = getSizeSubRegionSize() > 0 ? getSizeSubRegionSize() : bufferSize; final int returnEvent; if (batchSize <= 0) { returnEvent = deviceContext.enqueueReadBuffer(executionPlanId, toBuffer(), bufferSize, segment.address(), hostOffset, (useDeps) ? events : null); } else { - returnEvent = deviceContext.enqueueReadBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, bufferSize - TornadoNativeArray.ARRAY_HEADER, segment.address(), hostOffset, + returnEvent = deviceContext.enqueueReadBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, numBytes, segment.address(), hostOffset, (useDeps) ? events : null); } return useDeps ? returnEvent : -1; @@ -154,13 +155,14 @@ public List enqueueWrite(long executionPlanId, Object reference, long b MemorySegment segment = getSegmentWithHeader(reference); + final long numBytes = getSizeSubRegionSize() > 0 ? getSizeSubRegionSize() : bufferSize; int internalEvent; if (batchSize <= 0) { - internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferSize, segment.address(), hostOffset, (useDeps) ? events : null); + internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), numBytes, segment.address(), hostOffset, (useDeps) ? events : null); } else { internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), TornadoNativeArray.ARRAY_HEADER, segment.address(), 0, (useDeps) ? events : null); returnEvents.add(internalEvent); - internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, bufferSize, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, + internalEvent = deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, numBytes, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, (useDeps) ? events : null); } returnEvents.add(internalEvent); diff --git a/tornado-unittests/pom.xml b/tornado-unittests/pom.xml index 17c0bf8b6d..b3d0d4cd18 100644 --- a/tornado-unittests/pom.xml +++ b/tornado-unittests/pom.xml @@ -48,13 +48,7 @@ org.junit.jupiter junit-jupiter - RELEASE - compile - - - org.junit.jupiter - junit-jupiter - RELEASE + ${junit.jupiter.version} compile diff --git a/tornado-unittests/src/main/java/module-info.java b/tornado-unittests/src/main/java/module-info.java index 329d8ef948..da78c5afab 100644 --- a/tornado-unittests/src/main/java/module-info.java +++ b/tornado-unittests/src/main/java/module-info.java @@ -8,6 +8,7 @@ requires org.junit.platform.launcher; requires jdk.incubator.vector; requires com.microsoft.onnxruntime; + requires junit; exports uk.ac.manchester.tornado.unittests; diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java index 4f3965a1cb..e56ca674de 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/batches/TestBatches.java @@ -24,6 +24,7 @@ import java.util.Random; import java.util.stream.IntStream; +import org.junit.Ignore; import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -38,6 +39,7 @@ import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.unittests.common.TornadoNotSupported; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; import uk.ac.manchester.tornado.unittests.tools.Exceptions.UnsupportedConfigurationException; @@ -208,7 +210,7 @@ public void test100MBSmall() { executionPlan.freeDeviceMemory(); } - @Test + @TornadoNotSupported public void test100MBSmallLazy() { long maxAllocMemory = checkMaxHeapAllocationOnDevice(100, MemoryUnit.MB); @@ -276,7 +278,7 @@ public void test100MB() { executionPlan.freeDeviceMemory(); } - @Test + @TornadoNotSupported public void test100MBLazy() { long maxAllocMemory = checkMaxHeapAllocationOnDevice(100, MemoryUnit.MB); @@ -345,7 +347,7 @@ public void test300MB() { executionPlan.freeDeviceMemory(); } - @Test + @TornadoNotSupported public void test300MBLazy() { long maxAllocMemory = checkMaxHeapAllocationOnDevice(300, MemoryUnit.MB); @@ -413,7 +415,7 @@ public void test512MB() { executionPlan.freeDeviceMemory(); } - @Test + @TornadoNotSupported public void test512MBLazy() { long maxAllocMemory = checkMaxHeapAllocationOnDevice(512, MemoryUnit.MB); @@ -863,7 +865,7 @@ public void testBatchNotEven2() { executionPlan.freeDeviceMemory(); } - @Test + @TornadoNotSupported public void testBatchNotEven2Lazy() { checkMaxHeapAllocationOnDevice(64, MemoryUnit.MB); @@ -938,4 +940,4 @@ private enum MemoryUnit { MB, GB, TB } -} \ No newline at end of file +} From c7960a18c61673884012b64a6ea05650e8c7eb07 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 3 Sep 2024 13:51:08 +0300 Subject: [PATCH 10/11] Update JUnit versions in POM files --- pom.xml | 4 ++-- tornado-unittests/pom.xml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index f8a38125e2..b0e6a3b4d6 100644 --- a/pom.xml +++ b/pom.xml @@ -25,8 +25,8 @@ 1.29 jmhbenchmarks tornado-assembly/src/etc/checkstyle.xml - 5.10.2 - 1.9.0 + 5.11.0 + 1.11.0 diff --git a/tornado-unittests/pom.xml b/tornado-unittests/pom.xml index 23aa39b5c4..aab9c9f80b 100644 --- a/tornado-unittests/pom.xml +++ b/tornado-unittests/pom.xml @@ -4,8 +4,8 @@ xmlns="http://maven.apache.org/POM/4.0.0"> 4.0.0 - 5.9.0 - 1.9.0 + 5.11.0 + 1.11.0 tornado From 0214abcc94fd75adf6d32b59301729d87e626626 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 3 Sep 2024 14:41:14 +0300 Subject: [PATCH 11/11] Switch from JUnit 4 to JUnit 5 and refactor asserts --- tornado-unittests/pom.xml | 6 +++ .../api/TestBuildFromByteBuffers.java | 2 +- .../api/TestMemorySegmentsAsType.java | 44 ++++++++++--------- .../compiler/TestCompilerFlagsAPI.java | 2 +- .../memory/MemoryConsumptionTest.java | 2 +- .../TestMultiThreadedExecutionPlans.java | 7 +-- .../unittests/runtime/TestRuntimeAPI.java | 2 +- 7 files changed, 35 insertions(+), 30 deletions(-) diff --git a/tornado-unittests/pom.xml b/tornado-unittests/pom.xml index aab9c9f80b..99ab12bc60 100644 --- a/tornado-unittests/pom.xml +++ b/tornado-unittests/pom.xml @@ -51,5 +51,11 @@ ${junit.jupiter.version} compile + + org.testng + testng + RELEASE + compile + diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestBuildFromByteBuffers.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestBuildFromByteBuffers.java index 6253452dc5..ce8f771f53 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestBuildFromByteBuffers.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestBuildFromByteBuffers.java @@ -28,7 +28,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.CharArray; diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestMemorySegmentsAsType.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestMemorySegmentsAsType.java index 54b791691f..a125e31d9f 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestMemorySegmentsAsType.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestMemorySegmentsAsType.java @@ -18,7 +18,15 @@ package uk.ac.manchester.tornado.unittests.api; -import org.junit.Test; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +import org.junit.jupiter.api.Test; + import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; @@ -27,12 +35,6 @@ import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; -import java.lang.foreign.ValueLayout; - -import static java.lang.foreign.ValueLayout.JAVA_INT; - /** *

* How to run. @@ -48,23 +50,25 @@ private static void getMemorySegment(MemorySegment a) { float test = a.getAtIndex(ValueLayout.JAVA_FLOAT, 5); } - @Test(expected = TornadoRuntimeException.class) + @Test public void testMemorySegmentAsInput() throws TornadoExecutionPlanException { - MemorySegment segment; - long segmentByteSize = numElements * ValueLayout.JAVA_FLOAT.byteSize(); + assertThrows(TornadoRuntimeException.class, () -> { + MemorySegment segment; + long segmentByteSize = numElements * ValueLayout.JAVA_FLOAT.byteSize(); - segment = Arena.ofAuto().allocate(segmentByteSize, 1); - segment.setAtIndex(JAVA_INT, 0, numElements); + segment = Arena.ofAuto().allocate(segmentByteSize, 1); + segment.setAtIndex(JAVA_INT, 0, numElements); - TaskGraph taskGraph = new TaskGraph("s0") // - .transferToDevice(DataTransferMode.FIRST_EXECUTION, segment) // - .task("t0", TestMemorySegmentsAsType::getMemorySegment, segment) // - .transferToHost(DataTransferMode.EVERY_EXECUTION, segment); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, segment) // + .task("t0", TestMemorySegmentsAsType::getMemorySegment, segment) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, segment); - ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); - try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { - executionPlan.execute(); - } + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.execute(); + } + }); } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/compiler/TestCompilerFlagsAPI.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/compiler/TestCompilerFlagsAPI.java index 0d678c6167..1ed095a1c2 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/compiler/TestCompilerFlagsAPI.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/compiler/TestCompilerFlagsAPI.java @@ -17,7 +17,7 @@ */ package uk.ac.manchester.tornado.unittests.compiler; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/MemoryConsumptionTest.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/MemoryConsumptionTest.java index c36bcea544..d0bda9763a 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/MemoryConsumptionTest.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/memory/MemoryConsumptionTest.java @@ -19,7 +19,7 @@ import static org.junit.Assert.assertEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java index 873fd1f122..3ef3c0c492 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/multithreaded/TestMultiThreadedExecutionPlans.java @@ -17,8 +17,7 @@ */ package uk.ac.manchester.tornado.unittests.multithreaded; -import static org.junit.Assert.fail; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; @@ -210,14 +209,12 @@ public void test03() { try { t1.join(); } catch (InterruptedException e) { - assertTrue(false, "Error"); fail("Error"); } try { t2.join(); } catch (InterruptedException e) { - assertTrue(false, "Error"); fail("Error"); } @@ -249,14 +246,12 @@ public void test04() { try { t1.join(); } catch (InterruptedException e) { - assertTrue(false, "Error"); fail("Error"); } try { t2.join(); } catch (InterruptedException e) { - assertTrue(false, "Error"); fail("Error"); } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/runtime/TestRuntimeAPI.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/runtime/TestRuntimeAPI.java index 250615b3f9..9e0917c3f2 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/runtime/TestRuntimeAPI.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/runtime/TestRuntimeAPI.java @@ -20,7 +20,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -import org.junit.Test; +import org.junit.jupiter.api.Test; import uk.ac.manchester.tornado.api.TornadoBackend; import uk.ac.manchester.tornado.api.TornadoRuntime;