Skip to content

Commit

Permalink
Merge pull request #32 from WSE-research/addingTests
Browse files Browse the repository at this point in the history
Refactoring and Testing
  • Loading branch information
dschiese authored May 27, 2024
2 parents 3cbc6d3 + eb39b1b commit 2ddc309
Show file tree
Hide file tree
Showing 23 changed files with 265 additions and 263 deletions.
22 changes: 11 additions & 11 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.openlinksw</groupId>
<artifactId>virt_jena4_4</artifactId>
<version>1.43</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.openlinksw/virtjdbc4 -->
<dependency>
<groupId>com.openlinksw</groupId>
<artifactId>virtjdbc4_3</artifactId>
<version>3.123</version>
</dependency>
<dependency>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
Expand Down Expand Up @@ -83,12 +94,6 @@
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<dependency>
<groupId>org.apache.jena</groupId>
<artifactId>apache-jena-libs</artifactId>
<version>4.8.0</version>
<type>pom</type>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
Expand All @@ -100,11 +105,6 @@
<artifactId>spring-webflux</artifactId>
<version>6.0.11</version>
</dependency>
<dependency>
<groupId>org.apache.jena</groupId>
<artifactId>jena-arq</artifactId>
<version>4.8.0</version>
</dependency>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

@RestController
public class AutomatedTestController {

private final String authToken = ""; // Todo

@Autowired
private AutomatedTestingService automatedTestingService;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ public class ClientController {
@Autowired
private ClientService clientService;

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.io.IOException;
import java.util.Map;

@RestController
@ControllerAdvice
Expand All @@ -23,9 +21,6 @@ public class ExplanationController {
@Autowired
private ExplanationService explanationService;

@Value("${gpt.request.auth}")
private String gptRequestAuthToken;

/**
* Computes the explanations for (currently) the output data for a specific graph and/or component
*
Expand Down Expand Up @@ -106,7 +101,7 @@ public ResponseEntity<?> getOutputExplanation(
@PathVariable String componentURI,
@RequestHeader(value = "accept", required = false) String acceptHeader) {
try {
String explanationInFormattedString = explanationService.getTemplateComponentExplanation(graphURI, componentURI, null);
String explanationInFormattedString = explanationService.getTemplateComponentExplanation(graphURI, componentURI, acceptHeader);
return new ResponseEntity<>(explanationInFormattedString, HttpStatus.OK);
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import com.wse.qanaryexplanationservice.helper.pojos.AutomatedTests.QanaryRequestPojos.QanaryRequestObject;
import com.wse.qanaryexplanationservice.helper.pojos.AutomatedTests.QanaryRequestPojos.QanaryResponseObject;
import org.apache.jena.query.QueryExecution;
import org.apache.jena.query.ResultSet;
import org.apache.jena.rdfconnection.RDFConnection;
import org.apache.jena.query.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
Expand All @@ -14,65 +12,66 @@
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.netty.http.client.HttpClient;
import virtuoso.jena.driver.VirtGraph;
import virtuoso.jena.driver.VirtuosoQueryExecution;
import virtuoso.jena.driver.VirtuosoQueryExecutionFactory;

import java.time.Duration;

/**
* This class provides different request methods against the Qanary pipeline or the underlying triplestore
*/
@Repository
public class QanaryRepository {

private final static WebClient webClient = WebClient.builder().clientConnector(new ReactorClientHttpConnector(HttpClient.create().responseTimeout(Duration.ofSeconds(60)))).build();
private final static Logger logger = LoggerFactory.getLogger(QanaryRequestObject.class);
private static String QANARY_PIPELINE_HOST;
private static int QANARY_PIPELINE_PORT;
private static RDFConnection connection;
private static String sparqlendpoint;
private final WebClient webClient = WebClient.builder().clientConnector(new ReactorClientHttpConnector(HttpClient.create().responseTimeout(Duration.ofSeconds(60)))).build();
private final Logger logger = LoggerFactory.getLogger(QanaryRequestObject.class);
@Value("${qanary.pipeline.host}")
private String QANARY_PIPELINE_HOST;
@Value("${qanary.pipeline.port}")
private int QANARY_PIPELINE_PORT;
private VirtGraph connection;
@Value("${virtuoso.triplestore.endpoint}")
private String virtuosoEndpoint;
@Value("${virtuoso.triplestore.username}")
private String virtuosoUser;
@Value("${virtuoso.triplestore.password}")
private String virtuosoPassword;
@Value("${qanary.pipeline.host}")
private String qanaryHost;
@Value("${qanary.pipeline.port}")
private int qanaryPort;

public QanaryRepository() {
}

public static RDFConnection getConnection() {
return connection;
}

public static QanaryResponseObject executeQanaryPipeline(QanaryRequestObject qanaryRequestObject) {
public QanaryResponseObject executeQanaryPipeline(QanaryRequestObject qanaryRequestObject) {

MultiValueMap<String, String> multiValueMap = new LinkedMultiValueMap();
multiValueMap.add("question", qanaryRequestObject.getQuestion());
multiValueMap.addAll(qanaryRequestObject.getComponentListAsMap());

QanaryResponseObject responseObject = webClient.post().uri(uriBuilder -> uriBuilder // TODO: use new endpoint for question answering
return webClient.post().uri(uriBuilder -> uriBuilder // TODO: use new endpoint for question answering
.scheme("http").host(QANARY_PIPELINE_HOST).port(QANARY_PIPELINE_PORT).path("/startquestionansweringwithtextquestion")
.queryParams(multiValueMap)
.build())
.retrieve()
.bodyToMono(QanaryResponseObject.class)
.block();

logger.info("Response Object: {}", responseObject);

return responseObject;
}

public static ResultSet selectWithResultSet(String sparql) {
logger.warn("Executing with SPARQL endpoint {}", sparqlendpoint);
QueryExecution queryExecution = connection.query(sparql);
return queryExecution.execSelect();
}

@Value("${qanary.pipeline.host}")
public void setQanaryPipelineHost(String qanaryPipelineHost) {
QANARY_PIPELINE_HOST = qanaryPipelineHost;
}

@Value("${qanary.pipeline.port}")
public void setQanaryPipelinePort(int qanaryPipelinePort) {
QANARY_PIPELINE_PORT = qanaryPipelinePort;
public ResultSet selectWithResultSet(String sparql) throws QueryException {
if (connection == null)
initConnection();
Query query = QueryFactory.create(sparql);
VirtuosoQueryExecution vqe = VirtuosoQueryExecutionFactory.create(query, this.connection);
ResultSetRewindable results = ResultSetFactory.makeRewindable(vqe.execSelect());
return results;
}

@Value("${sparql.endpoint}")
public void setVirtuosoEndpoint(String sparqlEndpoint) {
sparqlendpoint = sparqlEndpoint;
connection = RDFConnection.connect(sparqlEndpoint);
public void initConnection() {
logger.info("Init connection for Qanary repository: {}", this.virtuosoEndpoint);
connection = new VirtGraph(this.virtuosoEndpoint, this.virtuosoUser, this.virtuosoPassword);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Service;

Expand All @@ -34,13 +33,13 @@ public class AutomatedTestingService {
private final Logger logger = LoggerFactory.getLogger(AutomatedTestingService.class);
// stores the correct template for different x-shot approaches
private final Map<Integer, String> exampleCountAndTemplate = new HashMap<>() {{
put(1, "/testtemplates/oneshot");
put(2, "/testtemplates/twoshot");
put(3, "/testtemplates/threeshot");
put(1, "/prompt_templates/outputdata/oneshot");
put(2, "/prompt_templates/outputdata/twoshot");
put(3, "/prompt_templates/outputdata/threeshot");
}};


private final Random random;
@Value("${questionId.replacement}")
private String QUESTION_ID_REPLACEMENT;
@Autowired
private GenerativeExplanations generativeExplanations;
@Value("${explanations.dataset.limit}")
Expand All @@ -51,7 +50,7 @@ public class AutomatedTestingService {
private ExplanationDataService explanationDataService;

// CONSTRUCTOR(s)
public AutomatedTestingService(Environment environment) {
public AutomatedTestingService() {
this.random = new Random();
}

Expand All @@ -61,8 +60,7 @@ public String selectComponent(AnnotationType annotationType, AutomatedTest autom

// Case if example is null, e.g. when the testing data is calculated
if (example == null || !example.getUniqueComponent()) {
int selectedComponentAsInt = random.nextInt(componentsList.length);
return componentsList[selectedComponentAsInt];
return selectRandomComponentFromComponentList(componentsList);
}
// Case if component should be unique in the whole test-case
else {
Expand All @@ -76,12 +74,18 @@ public String selectComponent(AnnotationType annotationType, AutomatedTest autom
componentList.remove(rnd); // Remove visited item -> list.size()-1 -> prevent infinite loop
} while (usedComponentsInTest.contains(component));
} catch (Exception e) {
throw new RuntimeException("There is no other unique and unused component for type " + annotationType.name());
logger.error("There is no other unique and unused component for type {}, select other component...", annotationType.name());
return selectRandomComponentFromComponentList(componentsList);
}
return component;
}
}

public String selectRandomComponentFromComponentList(String[] componentsList) {
int selectedComponentAsInt = random.nextInt(componentsList.length);
return componentsList[selectedComponentAsInt];
}

public ArrayList<String> fetchUsedComponents(AutomatedTest automatedTest) {
ArrayList<String> list = new ArrayList<>();
list.add(automatedTest.getTestData().getUsedComponent()); // Adds test-data component
Expand All @@ -90,7 +94,6 @@ public ArrayList<String> fetchUsedComponents(AutomatedTest automatedTest) {
for (TestDataObject item : listExamples) { // Adds every currently known component to the list
list.add(item.getUsedComponent());
}

return list;
}

Expand Down Expand Up @@ -165,15 +168,15 @@ public TestDataObject computeSingleTestObject(AnnotationType givenAnnotationType
logger.info("Execute Qanary pipeline");
QanaryResponseObject qanaryResponse = generativeExplanations.executeQanaryPipeline(question, componentListForQanaryPipeline);
String graphURI = qanaryResponse.getOutGraph();
String questionID = qanaryResponse.getQuestion().replace("http://localhost:8080/question/stored-question__text_", "questionID:");
String questionID = qanaryResponse.getQuestion().replace(QUESTION_ID_REPLACEMENT + "/question/stored-question__text_", "questionID:");

// Create dataset
logger.info("Create dataset");
String dataset = generativeExplanations.createDataset(selectedComponent, graphURI, givenAnnotationType.name());

// Create Explanation for selected component
logger.info("Create explanation");
String explanation = generativeExplanationsService.getTemplateExplanation(graphURI, selectedComponent);
String explanation = generativeExplanationsService.getTemplateExplanation(graphURI, selectedComponent, "en");
return new TestDataObject(
givenAnnotationType,
givenAnnotationType.ordinal(),
Expand Down Expand Up @@ -231,7 +234,6 @@ public String createTestWorkflow(AutomatedTestRequestBody requestBody, boolean d
AutomatedTest test;

while (jsonArray.length() < requestBody.getRuns()) {
logger.info("CURRENT RUN: {}", jsonArray.length() + 1);
test = createTest(requestBody); // null if not successful
if (test != null) {
if (doGptCall) {
Expand All @@ -240,7 +242,7 @@ public String createTestWorkflow(AutomatedTestRequestBody requestBody, boolean d
}
JSONObject finishedTest = new JSONObject(test);
jsonArray.put(finishedTest); // Add test to Json-Array
explanationDataService.insertDataset(test);
explanationDataService.insertDataset(test, doGptCall);
} else
logger.info("Skipped run due to null-ResultSet");
}
Expand All @@ -257,8 +259,6 @@ public Integer selectComponentAsInt(AnnotationType annotationType) {

}

// TODO: don't retry already used combinations




Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,11 @@ public ExplanationDataService() {
*
* @param automatedTest Object which will be parsed to a model and inserted to the triplestore
*/
public void insertDataset(AutomatedTest automatedTest) {
public void insertDataset(AutomatedTest automatedTest, boolean isGptCall) {

String uuid = UUID.randomUUID().toString();
List<Statement> statementList = createSpecificStatementList(automatedTest, uuid);
List<Statement> statementList = createSpecificStatementList(automatedTest, uuid, isGptCall);
model.add(statementList);
// TODO: Move connection details (e.g. application.properties / ...)
VirtModel virtModel = VirtModel.openDatabaseModel("urn:aex:" + uuid, VIRTUOSO_TRIPLESTORE_ENDPOINT, VIRTUOSO_TRIPLESTORE_USERNAME, VIRTUOSO_TRIPLESTORE_PASSWORD);
virtModel.add(model); // TODO: Auslagern des VirtModel Aufrufs
virtModel.close();
Expand All @@ -98,14 +97,14 @@ public void insertDataset(AutomatedTest automatedTest) {
* @param uuid graph identifier
* @return List of Statements (= triples)
*/
public List<Statement> createSpecificStatementList(AutomatedTest automatedTest, String uuid) {
public List<Statement> createSpecificStatementList(AutomatedTest automatedTest, String uuid, boolean isGptCall) {
List<Statement> statementList = new ArrayList<>();
Resource experimentId = model.createResource(uuid);
logger.info("Experiment ID: {}", experimentId);


statementList.add(ResourceFactory.createStatement(experimentId, prompt, ResourceFactory.createPlainLiteral(automatedTest.getPrompt())));
//statementList.add(ResourceFactory.createStatement(experimentId, gptExplanation, ResourceFactory.createPlainLiteral(automatedTest.getGptExplanation())));
if (isGptCall)
statementList.add(ResourceFactory.createStatement(experimentId, gptExplanation, ResourceFactory.createPlainLiteral(automatedTest.getGptExplanation())));
statementList.add(ResourceFactory.createStatement(experimentId, testData, setUpTestObject(automatedTest.getTestData())));
statementList.add(ResourceFactory.createStatement(experimentId, exampleData, setupExampleData(automatedTest.getExampleData())));

Expand Down
Loading

0 comments on commit 2ddc309

Please sign in to comment.