Skip to content

Commit

Permalink
Refactored many services, centralized methods, adjusted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dschiese committed May 27, 2024
1 parent d2d316b commit a54b906
Show file tree
Hide file tree
Showing 21 changed files with 225 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

@RestController
public class AutomatedTestController {

private final String authToken = ""; // Todo

@Autowired
private AutomatedTestingService automatedTestingService;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ public class ClientController {
@Autowired
private ClientService clientService;

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.io.IOException;
import java.util.Map;

@RestController
@ControllerAdvice
Expand All @@ -23,9 +21,6 @@ public class ExplanationController {
@Autowired
private ExplanationService explanationService;

@Value("${gpt.request.auth}")
private String gptRequestAuthToken;

/**
* Computes the explanations for (currently) the output data for a specific graph and/or component
*
Expand Down Expand Up @@ -106,7 +101,7 @@ public ResponseEntity<?> getOutputExplanation(
@PathVariable String componentURI,
@RequestHeader(value = "accept", required = false) String acceptHeader) {
try {
String explanationInFormattedString = explanationService.getTemplateComponentExplanation(graphURI, componentURI, null);
String explanationInFormattedString = explanationService.getTemplateComponentExplanation(graphURI, componentURI, acceptHeader);
return new ResponseEntity<>(explanationInFormattedString, HttpStatus.OK);
} catch (Exception e) {
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import com.wse.qanaryexplanationservice.helper.pojos.AutomatedTests.QanaryRequestPojos.QanaryRequestObject;
import com.wse.qanaryexplanationservice.helper.pojos.AutomatedTests.QanaryRequestPojos.QanaryResponseObject;
import org.apache.jena.query.QueryExecution;
import org.apache.jena.query.ResultSet;
import org.apache.jena.rdfconnection.RDFConnection;
import org.apache.jena.query.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
Expand All @@ -14,65 +12,61 @@
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.netty.http.client.HttpClient;
import virtuoso.jena.driver.VirtGraph;
import virtuoso.jena.driver.VirtuosoQueryExecution;
import virtuoso.jena.driver.VirtuosoQueryExecutionFactory;

import java.time.Duration;

/**
* This class provides different request methods against the Qanary pipeline or the underlying triplestore
*/
@Repository
public class QanaryRepository {

private final static WebClient webClient = WebClient.builder().clientConnector(new ReactorClientHttpConnector(HttpClient.create().responseTimeout(Duration.ofSeconds(60)))).build();
private final static Logger logger = LoggerFactory.getLogger(QanaryRequestObject.class);
private static String QANARY_PIPELINE_HOST;
private static int QANARY_PIPELINE_PORT;
private static RDFConnection connection;
private static String sparqlendpoint;

public QanaryRepository() {
}

public static RDFConnection getConnection() {
return connection;
private final WebClient webClient = WebClient.builder().clientConnector(new ReactorClientHttpConnector(HttpClient.create().responseTimeout(Duration.ofSeconds(60)))).build();
private final Logger logger = LoggerFactory.getLogger(QanaryRequestObject.class);
private final String QANARY_PIPELINE_HOST;
private final int QANARY_PIPELINE_PORT;
private VirtGraph connection;

public QanaryRepository(
@Value("${virtuoso.triplestore.endpoint}") String virtuosoEndpoint,
@Value("${virtuoso.triplestore.username}") String virtuosoUser,
@Value("${virtuoso.triplestore.password}") String virtuosoPassword,
@Value("${qanary.pipeline.host}") String qanaryHost,
@Value("${qanary.pipeline.port}") int qanaryPort
) {
this.initConnection(virtuosoEndpoint, virtuosoUser, virtuosoPassword);
this.QANARY_PIPELINE_HOST = qanaryHost;
this.QANARY_PIPELINE_PORT = qanaryPort;
}

public static QanaryResponseObject executeQanaryPipeline(QanaryRequestObject qanaryRequestObject) {
public QanaryResponseObject executeQanaryPipeline(QanaryRequestObject qanaryRequestObject) {

MultiValueMap<String, String> multiValueMap = new LinkedMultiValueMap();
multiValueMap.add("question", qanaryRequestObject.getQuestion());
multiValueMap.addAll(qanaryRequestObject.getComponentListAsMap());

QanaryResponseObject responseObject = webClient.post().uri(uriBuilder -> uriBuilder // TODO: use new endpoint for question answering
return webClient.post().uri(uriBuilder -> uriBuilder // TODO: use new endpoint for question answering
.scheme("http").host(QANARY_PIPELINE_HOST).port(QANARY_PIPELINE_PORT).path("/startquestionansweringwithtextquestion")
.queryParams(multiValueMap)
.build())
.retrieve()
.bodyToMono(QanaryResponseObject.class)
.block();

logger.info("Response Object: {}", responseObject);

return responseObject;
}

public static ResultSet selectWithResultSet(String sparql) {
logger.warn("Executing with SPARQL endpoint {}", sparqlendpoint);
QueryExecution queryExecution = connection.query(sparql);
return queryExecution.execSelect();
}

@Value("${qanary.pipeline.host}")
public void setQanaryPipelineHost(String qanaryPipelineHost) {
QANARY_PIPELINE_HOST = qanaryPipelineHost;
}

@Value("${qanary.pipeline.port}")
public void setQanaryPipelinePort(int qanaryPipelinePort) {
QANARY_PIPELINE_PORT = qanaryPipelinePort;
public ResultSet selectWithResultSet(String sparql) throws QueryException {
Query query = QueryFactory.create(sparql);
VirtuosoQueryExecution vqe = VirtuosoQueryExecutionFactory.create(query, this.connection);
ResultSetRewindable results = ResultSetFactory.makeRewindable(vqe.execSelect());
return results;
}

@Value("${sparql.endpoint}")
public void setVirtuosoEndpoint(String sparqlEndpoint) {
sparqlendpoint = sparqlEndpoint;
connection = RDFConnection.connect(sparqlEndpoint);
public void initConnection(String virtEndpoint, String virtUser, String virtPassword) {
logger.info("Init connection for Qanary repository: {}", virtEndpoint);
connection = new VirtGraph(virtEndpoint, virtUser, virtPassword);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Service;

Expand All @@ -34,13 +33,13 @@ public class AutomatedTestingService {
private final Logger logger = LoggerFactory.getLogger(AutomatedTestingService.class);
// stores the correct template for different x-shot approaches
private final Map<Integer, String> exampleCountAndTemplate = new HashMap<>() {{
put(1, "/testtemplates/oneshot");
put(2, "/testtemplates/twoshot");
put(3, "/testtemplates/threeshot");
put(1, "/prompt_templates/outputdata/oneshot");
put(2, "/prompt_templates/outputdata/twoshot");
put(3, "/prompt_templates/outputdata/threeshot");
}};


private final Random random;
@Value("${questionId.replacement}")
private String QUESTION_ID_REPLACEMENT;
@Autowired
private GenerativeExplanations generativeExplanations;
@Value("${explanations.dataset.limit}")
Expand All @@ -51,7 +50,7 @@ public class AutomatedTestingService {
private ExplanationDataService explanationDataService;

// CONSTRUCTOR(s)
public AutomatedTestingService(Environment environment) {
public AutomatedTestingService() {
this.random = new Random();
}

Expand All @@ -61,8 +60,7 @@ public String selectComponent(AnnotationType annotationType, AutomatedTest autom

// Case if example is null, e.g. when the testing data is calculated
if (example == null || !example.getUniqueComponent()) {
int selectedComponentAsInt = random.nextInt(componentsList.length);
return componentsList[selectedComponentAsInt];
return selectRandomComponentFromComponentList(componentsList);
}
// Case if component should be unique in the whole test-case
else {
Expand All @@ -76,12 +74,18 @@ public String selectComponent(AnnotationType annotationType, AutomatedTest autom
componentList.remove(rnd); // Remove visited item -> list.size()-1 -> prevent infinite loop
} while (usedComponentsInTest.contains(component));
} catch (Exception e) {
throw new RuntimeException("There is no other unique and unused component for type " + annotationType.name());
logger.error("There is no other unique and unused component for type {}, select other component...", annotationType.name());
return selectRandomComponentFromComponentList(componentsList);
}
return component;
}
}

public String selectRandomComponentFromComponentList(String[] componentsList) {
int selectedComponentAsInt = random.nextInt(componentsList.length);
return componentsList[selectedComponentAsInt];
}

public ArrayList<String> fetchUsedComponents(AutomatedTest automatedTest) {
ArrayList<String> list = new ArrayList<>();
list.add(automatedTest.getTestData().getUsedComponent()); // Adds test-data component
Expand All @@ -90,7 +94,6 @@ public ArrayList<String> fetchUsedComponents(AutomatedTest automatedTest) {
for (TestDataObject item : listExamples) { // Adds every currently known component to the list
list.add(item.getUsedComponent());
}

return list;
}

Expand Down Expand Up @@ -165,15 +168,15 @@ public TestDataObject computeSingleTestObject(AnnotationType givenAnnotationType
logger.info("Execute Qanary pipeline");
QanaryResponseObject qanaryResponse = generativeExplanations.executeQanaryPipeline(question, componentListForQanaryPipeline);
String graphURI = qanaryResponse.getOutGraph();
String questionID = qanaryResponse.getQuestion().replace("http://localhost:8080/question/stored-question__text_", "questionID:");
String questionID = qanaryResponse.getQuestion().replace(QUESTION_ID_REPLACEMENT + "/question/stored-question__text_", "questionID:");

// Create dataset
logger.info("Create dataset");
String dataset = generativeExplanations.createDataset(selectedComponent, graphURI, givenAnnotationType.name());

// Create Explanation for selected component
logger.info("Create explanation");
String explanation = generativeExplanationsService.getTemplateExplanation(graphURI, selectedComponent);
String explanation = generativeExplanationsService.getTemplateExplanation(graphURI, selectedComponent, "en");
return new TestDataObject(
givenAnnotationType,
givenAnnotationType.ordinal(),
Expand Down Expand Up @@ -231,7 +234,6 @@ public String createTestWorkflow(AutomatedTestRequestBody requestBody, boolean d
AutomatedTest test;

while (jsonArray.length() < requestBody.getRuns()) {
logger.info("CURRENT RUN: {}", jsonArray.length() + 1);
test = createTest(requestBody); // null if not successful
if (test != null) {
if (doGptCall) {
Expand All @@ -240,7 +242,7 @@ public String createTestWorkflow(AutomatedTestRequestBody requestBody, boolean d
}
JSONObject finishedTest = new JSONObject(test);
jsonArray.put(finishedTest); // Add test to Json-Array
explanationDataService.insertDataset(test);
explanationDataService.insertDataset(test, doGptCall);
} else
logger.info("Skipped run due to null-ResultSet");
}
Expand All @@ -257,8 +259,6 @@ public Integer selectComponentAsInt(AnnotationType annotationType) {

}

// TODO: don't retry already used combinations




Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,11 @@ public ExplanationDataService() {
*
* @param automatedTest Object which will be parsed to a model and inserted to the triplestore
*/
public void insertDataset(AutomatedTest automatedTest) {
public void insertDataset(AutomatedTest automatedTest, boolean isGptCall) {

String uuid = UUID.randomUUID().toString();
List<Statement> statementList = createSpecificStatementList(automatedTest, uuid);
List<Statement> statementList = createSpecificStatementList(automatedTest, uuid, isGptCall);
model.add(statementList);
// TODO: Move connection details (e.g. application.properties / ...)
VirtModel virtModel = VirtModel.openDatabaseModel("urn:aex:" + uuid, VIRTUOSO_TRIPLESTORE_ENDPOINT, VIRTUOSO_TRIPLESTORE_USERNAME, VIRTUOSO_TRIPLESTORE_PASSWORD);
virtModel.add(model); // TODO: Auslagern des VirtModel Aufrufs
virtModel.close();
Expand All @@ -98,14 +97,14 @@ public void insertDataset(AutomatedTest automatedTest) {
* @param uuid graph identifier
* @return List of Statements (= triples)
*/
public List<Statement> createSpecificStatementList(AutomatedTest automatedTest, String uuid) {
public List<Statement> createSpecificStatementList(AutomatedTest automatedTest, String uuid, boolean isGptCall) {
List<Statement> statementList = new ArrayList<>();
Resource experimentId = model.createResource(uuid);
logger.info("Experiment ID: {}", experimentId);


statementList.add(ResourceFactory.createStatement(experimentId, prompt, ResourceFactory.createPlainLiteral(automatedTest.getPrompt())));
//statementList.add(ResourceFactory.createStatement(experimentId, gptExplanation, ResourceFactory.createPlainLiteral(automatedTest.getGptExplanation())));
if (isGptCall)
statementList.add(ResourceFactory.createStatement(experimentId, gptExplanation, ResourceFactory.createPlainLiteral(automatedTest.getGptExplanation())));
statementList.add(ResourceFactory.createStatement(experimentId, testData, setUpTestObject(automatedTest.getTestData())));
statementList.add(ResourceFactory.createStatement(experimentId, exampleData, setupExampleData(automatedTest.getExampleData())));

Expand Down
Loading

0 comments on commit a54b906

Please sign in to comment.