Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ testing {
implementation "org.apache.sshd:sshd-sftp:$sshdVersion"
implementation "org.apache.sshd:sshd-scp:$sshdVersion"
implementation "ch.qos.logback:logback-classic:1.5.18"
implementation 'org.glassfish.grizzly:grizzly-http-server:3.0.1'
}

targets {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshServerExtension;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import net.schmizz.sshj.connection.channel.direct.Parameters;
Expand All @@ -29,35 +28,30 @@

import java.io.*;
import java.net.*;
import java.nio.file.Files;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class LocalPortForwarderTest {
private static final String LOCALHOST_URL = "http://127.0.0.1:8080";

@RegisterExtension
public SshServerExtension fixture = new SshServerExtension();

@RegisterExtension
public HttpServer httpServer = new HttpServer();

@BeforeEach
public void setUp() throws IOException {
public void setUp() {
fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
File file = Files.createFile(httpServer.getDocRoot().toPath().resolve("index.html")).toFile();
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
}

@Test
public void shouldHaveWorkingHttpServer() throws IOException {
assertEquals(200, httpGet());
assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet());
}

@Test
public void shouldHaveHttpServerThatClosesConnectionAfterResponse() throws IOException {
// Just to check that the test server does close connections before we try through the forwarder...
httpGetAndAssertConnectionClosedByServer(8080);
httpGetAndAssertConnectionClosedByServer(httpServer.getServerUrl().getPort());
}

@Test
Expand All @@ -68,7 +62,8 @@ public void shouldCloseConnectionWhenRemoteServerClosesConnection() throws IOExc
ServerSocket serverSocket = new ServerSocket();
serverSocket.setReuseAddress(true);
serverSocket.bind(new InetSocketAddress("0.0.0.0", 12345));
LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", 8080), serverSocket);
final int serverPort = httpServer.getServerUrl().getPort();
LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", serverPort), serverSocket);
new Thread(() -> {
try {
localPortForwarder.listen();
Expand All @@ -90,7 +85,7 @@ public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOE
// It returns 400 Bad Request because it's missing a bunch of info, but the HTTP response doesn't matter, we just want to test the connection closing.
OutputStream outputStream = socket.getOutputStream();
PrintWriter writer = new PrintWriter(outputStream);
writer.println("GET / HTTP/1.1");
writer.println("GET / HTTP/1.1\r\n");
writer.println("");
writer.flush();

Expand All @@ -111,7 +106,7 @@ public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOE
}

private int httpGet() throws IOException {
final URL url = new URL(LOCALHOST_URL);
final URL url = httpServer.getServerUrl().toURL();
final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
urlConnection.setConnectTimeout(3000);
urlConnection.setRequestMethod("GET");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshServerExtension;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
Expand All @@ -27,20 +26,18 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.File;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URL;
import java.nio.file.Files;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class RemotePortForwarderTest {
private static final PortRange RANGE = new PortRange(9000, 9999);
private static final String LOCALHOST = "127.0.0.1";
private static final String LOCALHOST_URL_FORMAT = "http://127.0.0.1:%d";
private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress(LOCALHOST, 8080);
private static final String URL_FORMAT = "http://%s:%d";

@RegisterExtension
public SshServerExtension fixture = new SshServerExtension();
Expand All @@ -49,21 +46,21 @@ public class RemotePortForwarderTest {
public HttpServer httpServer = new HttpServer();

@BeforeEach
public void setUp() throws IOException {
public void setUp() {
fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
File file = Files.createFile(httpServer.getDocRoot().toPath().resolve("index.html")).toFile();
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
}

@Test
public void shouldHaveWorkingHttpServer() throws IOException {
assertEquals(200, httpGet(8080));
final URI serverUrl = httpServer.getServerUrl();

assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet(serverUrl.getHost(), serverUrl.getPort()));
}

@Test
public void shouldDynamicallyForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", new SinglePort(0));
RemotePortForwarder.Forward bind = forwardPort(sshClient, LOCALHOST, new SinglePort(0));
assertHttpGetSuccess(bind);
}

Expand All @@ -84,7 +81,7 @@ public void shouldDynamicallyForwardPortForAllProtocols() throws IOException {
@Test
public void shouldForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", RANGE);
RemotePortForwarder.Forward bind = forwardPort(sshClient, LOCALHOST, RANGE);
assertHttpGetSuccess(bind);
}

Expand All @@ -103,17 +100,22 @@ public void shouldForwardPortForAllProtocols() throws IOException {
}

private void assertHttpGetSuccess(final RemotePortForwarder.Forward bind) throws IOException {
assertEquals(200, httpGet(bind.getPort()));
final String bindAddress = bind.getAddress();
final String address = bindAddress.isEmpty() ? LOCALHOST : bindAddress;
final int port = bind.getPort();
assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet(address, port));
}

private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String address, PortRange portRange) throws IOException {
while (true) {
final URI serverUrl = httpServer.getServerUrl();
final InetSocketAddress serverAddress = new InetSocketAddress(serverUrl.getHost(), serverUrl.getPort());
try {
return sshClient.getRemotePortForwarder().bind(
// where the server should listen
new RemotePortForwarder.Forward(address, portRange.nextPort()),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(HTTP_SERVER_SOCKET_ADDR));
new SocketForwardingConnectListener(serverAddress));
} catch (ConnectionException ce) {
if (!portRange.hasNext()) {
throw ce;
Expand All @@ -122,8 +124,8 @@ private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String addr
}
}

private int httpGet(final int port) throws IOException {
final URL url = new URL(String.format(LOCALHOST_URL_FORMAT, port));
private int httpGet(final String address, final int port) throws IOException {
final URL url = new URL(String.format(URL_FORMAT, address, port));
final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
urlConnection.setConnectTimeout(3000);
urlConnection.setRequestMethod("GET");
Expand Down
34 changes: 14 additions & 20 deletions src/test/java/com/hierynomus/sshj/test/HttpServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,36 @@
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

import java.io.File;
import java.nio.file.Files;
import java.net.InetSocketAddress;
import java.net.URI;

/**
* Can be used to setup a test HTTP server
*/
public class HttpServer implements BeforeEachCallback, AfterEachCallback {

private org.glassfish.grizzly.http.server.HttpServer httpServer;
private static final String BIND_ADDRESS = "127.0.0.1";


private File docRoot ;
private com.sun.net.httpserver.HttpServer httpServer;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jikes... com.sun.net... As far as I remember, wasn't it an antipattern to depend on anything in com.sun as this is not guaranteed to be present in other JVMs.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scratch that, I've read up on it... Seems that I'm confusing sun.* and com.sun.*... Nothing to see here, thanks for the PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reviewing and merging @hierynomus! Yes, the package name for the JDK HttpServer causes a natural double take, so thanks for giving it a closer look, as it is part of the public JDK.


@Override
public void afterEach(ExtensionContext context) throws Exception {
try {
httpServer.shutdownNow();
} catch (Exception e) {}
public void afterEach(ExtensionContext context) {
try {
docRoot.delete();
} catch (Exception e) {}

httpServer.stop(0);
} catch (Exception ignored) {}
}

@Override
public void beforeEach(ExtensionContext context) throws Exception {
docRoot = Files.createTempDirectory("sshj").toFile();
httpServer = org.glassfish.grizzly.http.server.HttpServer.createSimpleServer(docRoot.getAbsolutePath());
httpServer = com.sun.net.httpserver.HttpServer.create();
final InetSocketAddress socketAddress = new InetSocketAddress(BIND_ADDRESS, 0);
httpServer.bind(socketAddress, 10);
httpServer.start();
}

public org.glassfish.grizzly.http.server.HttpServer getHttpServer() {
return httpServer;
}

public File getDocRoot() {
return docRoot;
public URI getServerUrl() {
final InetSocketAddress bindAddress = httpServer.getAddress();
final String serverUrl = String.format("http://%s:%d", BIND_ADDRESS, bindAddress.getPort());
return URI.create(serverUrl);
}
}