Skip to content

Commit

Permalink
Merge pull request #9 from Together-Java/feat/use-docker-java-client
Browse files Browse the repository at this point in the history
Feat/use docker java client
  • Loading branch information
Alathreon authored May 14, 2024
2 parents 7bb9d99 + 7a927d7 commit 8fcdb40
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 39 deletions.
3 changes: 3 additions & 0 deletions JShellAPI/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ repositories {
dependencies {
implementation project(':JShellWrapper')
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'com.github.docker-java:docker-java-transport-httpclient5:3.3.4'
implementation 'com.github.docker-java:docker-java-core:3.3.4'

testImplementation 'org.springframework.boot:spring-boot-starter-test'
annotationProcessor "org.springframework.boot:spring-boot-configuration-processor"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ public record Config(
long maxAliveSessions,
int dockerMaxRamMegaBytes,
double dockerCPUsUsage,
long schedulerSessionKillScanRateSeconds) {
long schedulerSessionKillScanRateSeconds,
long dockerResponseTimeout,
long dockerConnectionTimeout) {
public Config {
if(regularSessionTimeoutSeconds <= 0) throw new RuntimeException("Invalid value " + regularSessionTimeoutSeconds);
if(oneTimeSessionTimeoutSeconds <= 0) throw new RuntimeException("Invalid value " + oneTimeSessionTimeoutSeconds);
Expand All @@ -21,5 +23,7 @@ public record Config(
if(dockerMaxRamMegaBytes <= 0) throw new RuntimeException("Invalid value " + dockerMaxRamMegaBytes);
if(dockerCPUsUsage <= 0) throw new RuntimeException("Invalid value " + dockerCPUsUsage);
if(schedulerSessionKillScanRateSeconds <= 0) throw new RuntimeException("Invalid value " + schedulerSessionKillScanRateSeconds);
if(dockerResponseTimeout <= 0) throw new RuntimeException("Invalid value " + dockerResponseTimeout);
if(dockerConnectionTimeout <= 0) throw new RuntimeException("Invalid value " + dockerConnectionTimeout);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.togetherjava.jshellapi.service;

import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.command.PullImageResultCallback;
import com.github.dockerjava.api.model.*;
import com.github.dockerjava.core.DefaultDockerClientConfig;
import com.github.dockerjava.core.DockerClientImpl;
import com.github.dockerjava.httpclient5.ApacheDockerHttpClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.togetherjava.jshellapi.Config;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.TimeUnit;

@Service
public class DockerService implements DisposableBean {
private static final Logger LOGGER = LoggerFactory.getLogger(DockerService.class);
private static final String WORKER_LABEL = "jshell-api-worker";
private static final UUID WORKER_UNIQUE_ID = UUID.randomUUID();

private final DockerClient client;

public DockerService(Config config) {
DefaultDockerClientConfig clientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder().build();
ApacheDockerHttpClient httpClient = new ApacheDockerHttpClient.Builder()
.dockerHost(clientConfig.getDockerHost())
.sslConfig(clientConfig.getSSLConfig())
.responseTimeout(Duration.ofSeconds(config.dockerResponseTimeout()))
.connectionTimeout(Duration.ofSeconds(config.dockerConnectionTimeout()))
.build();
this.client = DockerClientImpl.getInstance(clientConfig, httpClient);

cleanupLeftovers(WORKER_UNIQUE_ID);
}

private void cleanupLeftovers(UUID currentId) {
for (Container container : client.listContainersCmd().withLabelFilter(Set.of(WORKER_LABEL)).exec()) {
String containerHumanName = container.getId() + " " + Arrays.toString(container.getNames());
LOGGER.info("Found worker container '{}'", containerHumanName);
if (!container.getLabels().get(WORKER_LABEL).equals(currentId.toString())) {
LOGGER.info("Killing container '{}'", containerHumanName);
client.killContainerCmd(container.getId()).exec();
}
}
}

public String spawnContainer(
long maxMemoryMegs, long cpus, String name, Duration evalTimeout, long sysoutLimit
) throws InterruptedException {
String imageName = "togetherjava.org:5001/togetherjava/jshellwrapper";
boolean presentLocally = client.listImagesCmd()
.withFilter("reference", List.of(imageName))
.exec()
.stream()
.flatMap(it -> Arrays.stream(it.getRepoTags()))
.anyMatch(it -> it.endsWith(":master"));

if (!presentLocally) {
client.pullImageCmd(imageName)
.withTag("master")
.exec(new PullImageResultCallback())
.awaitCompletion(5, TimeUnit.MINUTES);
}

return client.createContainerCmd(
imageName + ":master"
)
.withHostConfig(
HostConfig.newHostConfig()
.withAutoRemove(true)
.withInit(true)
.withCapDrop(Capability.ALL)
.withNetworkMode("none")
.withPidsLimit(2000L)
.withReadonlyRootfs(true)
.withMemory(maxMemoryMegs * 1024 * 1024)
.withCpuCount(cpus)
)
.withStdinOpen(true)
.withAttachStdin(true)
.withAttachStderr(true)
.withAttachStdout(true)
.withEnv("evalTimeoutSeconds=" + evalTimeout.toSeconds(), "sysOutCharLimit=" + sysoutLimit)
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
.withName(name)
.exec()
.getId();
}

public InputStream startAndAttachToContainer(String containerId, InputStream stdin) throws IOException {
PipedInputStream pipeIn = new PipedInputStream();
PipedOutputStream pipeOut = new PipedOutputStream(pipeIn);

client.attachContainerCmd(containerId)
.withLogs(true)
.withFollowStream(true)
.withStdOut(true)
.withStdErr(true)
.withStdIn(stdin)
.exec(new ResultCallback.Adapter<>() {
@Override
public void onNext(Frame object) {
try {
String payloadString = new String(object.getPayload(), StandardCharsets.UTF_8);
if (object.getStreamType() == StreamType.STDOUT) {
pipeOut.write(object.getPayload());
} else {
LOGGER.warn(
"Received STDERR from container {}: {}",
containerId,
payloadString
);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
});

client.startContainerCmd(containerId).exec();
return pipeIn;
}

public void killContainerByName(String name) {
for (Container container : client.listContainersCmd().withNameFilter(Set.of(name)).exec()) {
client.killContainerCmd(container.getId()).exec();
}
}

public boolean isDead(String containerName) {
return client.listContainersCmd().withNameFilter(Set.of(containerName)).exec().isEmpty();
}

@Override
public void destroy() throws Exception {
LOGGER.info("destroy() called. Destroying all containers...");
cleanupLeftovers(UUID.randomUUID());
client.close();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
import org.togetherjava.jshellapi.dto.*;
import org.togetherjava.jshellapi.exceptions.DockerException;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -18,16 +16,17 @@
public class JShellService implements Closeable {
private final JShellSessionService sessionService;
private final String id;
private Process process;
private final BufferedWriter writer;
private final BufferedReader reader;

private Instant lastTimeoutUpdate;
private final long timeout;
private final boolean renewable;
private boolean doingOperation;
private final DockerService dockerService;

public JShellService(JShellSessionService sessionService, String id, long timeout, boolean renewable, long evalTimeout, int sysOutCharLimit, int maxMemory, double cpus, String startupScript) throws DockerException {
public JShellService(DockerService dockerService, JShellSessionService sessionService, String id, long timeout, boolean renewable, long evalTimeout, int sysOutCharLimit, int maxMemory, double cpus, String startupScript) throws DockerException {
this.dockerService = dockerService;
this.sessionService = sessionService;
this.id = id;
this.timeout = timeout;
Expand All @@ -39,30 +38,23 @@ public JShellService(JShellSessionService sessionService, String id, long timeou
Files.createDirectories(errorLogs.getParent());
Files.createFile(errorLogs);
}
process = new ProcessBuilder(
"docker",
"run",
"--rm",
"-i",
"--init",
"--cap-drop=ALL",
"--network=none",
"--pids-limit=2000",
"--read-only",
"--memory=" + maxMemory + "m",
"--cpus=" + cpus,
"--name", containerName(),
"-e", "\"evalTimeoutSeconds=%d\"".formatted(evalTimeout),
"-e", "\"sysOutCharLimit=%d\"".formatted(sysOutCharLimit),
"togetherjava.org:5001/togetherjava/jshellwrapper:master")
.directory(new File(".."))
.redirectError(errorLogs.toFile())
.start();
writer = process.outputWriter();
reader = process.inputReader();
String containerId = dockerService.spawnContainer(
maxMemory,
(long) Math.ceil(cpus),
containerName(),
Duration.ofSeconds(evalTimeout),
sysOutCharLimit
);
PipedInputStream containerInput = new PipedInputStream();
this.writer = new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)));
InputStream containerOutput = dockerService.startAndAttachToContainer(
containerId,
containerInput
);
reader = new BufferedReader(new InputStreamReader(containerOutput));
writer.write(sanitize(startupScript));
writer.newLine();
} catch (IOException e) {
} catch (IOException | InterruptedException e) {
throw new DockerException(e);
}
this.doingOperation = false;
Expand All @@ -73,6 +65,10 @@ public Optional<JShellResult> eval(String code) throws DockerException {
return Optional.empty();
}
}
if (isClosed()) {
close();
return Optional.empty();
}
updateLastTimeout();
if(!code.endsWith("\n")) code += '\n';
try {
Expand All @@ -86,7 +82,7 @@ public Optional<JShellResult> eval(String code) throws DockerException {
checkContainerOK();

return Optional.of(readResult());
} catch (IOException | NumberFormatException ex) {
} catch (DockerException | IOException | NumberFormatException ex) {
close();
throw new DockerException(ex);
} finally {
Expand Down Expand Up @@ -185,27 +181,22 @@ public String id() {

@Override
public void close() {
process.destroyForcibly();
try {
try {
writer.close();
} finally {
reader.close();
}
new ProcessBuilder("docker", "kill", containerName())
.directory(new File(".."))
.start()
.waitFor();
} catch(IOException | InterruptedException ex) {
dockerService.killContainerByName(containerName());
} catch(IOException ex) {
throw new RuntimeException(ex);
}
process = null;
sessionService.notifyDeath(id);
}

@Override
public boolean isClosed() {
return process == null;
return dockerService.isDead(containerName());
}

private void updateLastTimeout() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public class JShellSessionService {
private Config config;
private StartupScriptsService startupScriptsService;
private ScheduledExecutorService scheduler;
private DockerService dockerService;
private final Map<String, JShellService> jshellSessions = new HashMap<>();

private void initScheduler() {
scheduler = Executors.newSingleThreadScheduledExecutor();
scheduler.scheduleAtFixedRate(() -> {
Expand Down Expand Up @@ -74,6 +76,7 @@ private synchronized JShellService createSession(String id, long sessionTimeout,
throw new ResponseStatusException(HttpStatus.TOO_MANY_REQUESTS, "Too many sessions, try again later :(.");
}
JShellService service = new JShellService(
dockerService,
this,
id,
sessionTimeout,
Expand All @@ -97,4 +100,9 @@ public void setConfig(Config config) {
public void setStartupScriptsService(StartupScriptsService startupScriptsService) {
this.startupScriptsService = startupScriptsService;
}

@Autowired
public void setDockerService(DockerService dockerService) {
this.dockerService = dockerService;
}
}
6 changes: 5 additions & 1 deletion JShellAPI/src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ jshellapi.dockerMaxRamMegaBytes=100
jshellapi.dockerCPUsUsage=0.5

# Internal config
jshellapi.schedulerSessionKillScanRateSeconds=60
jshellapi.schedulerSessionKillScanRateSeconds=60

# Docker service config
jshellapi.dockerResponseTimeout=60
jshellapi.dockerConnectionTimeout=60

0 comments on commit 8fcdb40

Please sign in to comment.