Skip to content

Commit

Permalink
Too many changes to count
Browse files Browse the repository at this point in the history
  • Loading branch information
vineet1992 committed Jun 12, 2019
1 parent 4687529 commit e798e3e
Show file tree
Hide file tree
Showing 13 changed files with 650 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ else if(args[index].equals("-rv"))

DataSet predictionData = data;

if(steps)
if(steps && !prefDiv)
{
System.out.print("Running STEPS...");
STEPS s = new STEPS(temp,initLambdas,g,samps);
Expand All @@ -249,7 +249,7 @@ else if(prefDiv)
printer.println(test);
printer.flush();
printer.close();
String cmd = "java -jar PrefDiv.jar -data temp_train.txt -dataTest temp_test.txt -loocv " + " -t " + target + " -name temp -numSelect " + numSelect + " -keep ";
String cmd = "java -jar PrefDiv.jar -data temp_train.txt -dataTest temp_test.txt -loocv " + "-t " + target + " -name temp_" + i + " -numSelect " + numSelect + " -keep ";

for(String s:toRemove)
{
Expand Down Expand Up @@ -299,12 +299,20 @@ else if(ctype== RunPrefDiv.ClusterType.MEDIAN)
System.err.println(s);
}

out = GraphUtils.loadGraphTxt(new File("temp/graph.txt"));
if(steps)
{
out = GraphUtils.loadGraphTxt(new File("temp_" + i + "/GRAPH_temp_train.txt"));
}
else
{
out = GraphUtils.loadGraphTxt(new File("temp_" + i + "/piMGM/Graph.txt"));
}
/***Construct training and testing set by getting list of genes and clusters from files***/

DataSet train = MixedUtils.loadDataSet2("temp/summarized.txt");
test = MixedUtils.loadDataSet2("temp/summarized_test.txt");
DataSet full = DataUtils.concatenate(train,test);
DataSet train = MixedUtils.loadDataSet2("temp_" + i + "/OUT_temp_train.txt");
test = MixedUtils.loadTestData("temp_" + i + "/summarized_test.txt",3,train);
DataSet full = train.copy();
MixedUtils.concatenateTo(test,full);

predictionData = full;
temp = train;
Expand All @@ -330,22 +338,26 @@ else if(ctype== RunPrefDiv.ClusterType.MEDIAN)
Graph cpcOut = cpc.search();


int testIndex = i;

if(prefDiv)
testIndex = predictionData.getNumRows()-1;

/***piMGM Predictions***/

List<Node> neighbors = out.getAdjacentNodes(out.getNode(target));
System.out.println("Adjacent To Target: " + neighbors);
double pred = getRegressionResult(neighbors,temp,predictionData,i,target);
double pred = getRegressionResult(neighbors,temp,predictionData,testIndex,target);
double real = -1;



if(predictionData.getVariable(target)instanceof DiscreteVariable)
{
real = predictionData.getInt(i,predictionData.getColumn(predictionData.getVariable(target)));
real = predictionData.getInt(testIndex,predictionData.getColumn(predictionData.getVariable(target)));
}else
{
real = predictionData.getDouble(i,predictionData.getColumn(predictionData.getVariable(target)));
real = predictionData.getDouble(testIndex,predictionData.getColumn(predictionData.getVariable(target)));
}
System.out.println("piMGM Prediction: " + pred);
pi1.print(i + "\t" + pred +"\t");
Expand All @@ -356,7 +368,7 @@ else if(ctype== RunPrefDiv.ClusterType.MEDIAN)
if(neighbors.size()==0) {
neighbors = out.getAdjacentNodes(out.getNode(target));
}
pred = getRegressionResult(neighbors,temp,predictionData,i,target);
pred = getRegressionResult(neighbors,temp,predictionData,testIndex,target);
System.out.println("CPC-MB Neighbors: " + neighbors);
System.out.println("CPC-MB Prediction: " + pred);

Expand All @@ -371,7 +383,7 @@ else if(ctype== RunPrefDiv.ClusterType.MEDIAN)
neighbors = cpcOut.getAdjacentNodes(cpcOut.getNode(target));
if(neighbors.size()==0)
neighbors = out.getAdjacentNodes(out.getNode(target));
pred = getRegressionResult(neighbors,temp,predictionData,i,target);
pred = getRegressionResult(neighbors,temp,predictionData,testIndex,target);
System.out.println("CPC Neighbors: " + neighbors);
System.out.println("CPC Prediction: " + pred);

Expand Down
142 changes: 138 additions & 4 deletions tetrad-lib/src/main/java/edu/pitt/csb/Pref_Div/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,106 @@ protected void compute(){



/***
* THIS METHOD RUNS IN PARALLEL
* @param items List of Genes, only the symbol needs to be filled in
* @param d DataSet d, a dataset on which to compute correlations
* return A float [] with all of the gene-gene correlations
***/
public static float [] computeAllAssociationsPar(final ArrayList<Gene> items, final DataSet d)
{

/***Construct mapping between indices and nodes***/
final HashMap<Integer,Integer> mapping = new HashMap<Integer,Integer>();
final HashMap<Integer,Integer> categories = new HashMap<Integer,Integer>();

/***Ensure correct mapping from genes to nodes in data***/
for(int i = 0; i < items.size();i++)
{
Node temp = d.getVariable(items.get(i).symbol);
mapping.put(i,d.getColumn(temp));
if(temp instanceof DiscreteVariable)
{
DiscreteVariable dv = (DiscreteVariable)temp;
categories.put(i,dv.getNumCategories());
}
}

/***Rows are variable, columns are samples***/
final double [][] datArray = d.getDoubleData().transpose().toArray();


int total = (items.size()*(items.size()-1))/2;
final float [] corrs = new float[total];

final ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();

class StabilityAction extends RecursiveAction {
private int chunk;
private int from;
private int to;

public StabilityAction(int chunk, int from, int to){
this.chunk = chunk;
this.from = from;
this.to = to;
}

@Override
protected void compute(){
if (to - from <= chunk) {
for (int s = from; s < to; s++) {
//RadiiWP is the columns, RadiiNP is the rows
double[] one = datArray[mapping.get(s)];

int index = Functions.getIndex(s,s+1,items.size());
for(int i = s+1; i < items.size();i++)
{
double [] two = datArray[mapping.get(i)];

if(categories.get(s)!=null || categories.get(i)!=null)
{
//corrs[index] = (float)mixedMI(two,one,categories.get(s));
IndTestMultinomialAJ ind = new IndTestMultinomialAJ(d,0.05,true);
ind.isIndependent(d.getVariable(mapping.get(i)),d.getVariable(mapping.get(s)));
corrs[index] = (float)(1-ind.getPValue());
}
else
{
corrs[index] = (float)Math.abs(StatUtils.correlation(one,two));
}


index++;
}
}

return;
} else {
List<StabilityAction> tasks = new ArrayList<>();

final int mid = (to + from) / 2;

tasks.add(new StabilityAction(chunk, from, mid));
tasks.add(new StabilityAction(chunk, mid, to));

invokeAll(tasks);

return;
}
}
}

final int chunk = d.getNumColumns()/Runtime.getRuntime().availableProcessors();
StabilityAction sa = new StabilityAction(chunk,0, items.size());
pool.invoke(sa);


return corrs;
}




/***
*
Expand Down Expand Up @@ -452,6 +552,7 @@ private static double[] getPartialCorr(TetradMatrix c, DataSet d, int x, int y)
int total = (items.size()*(items.size()-1))/2;
float [] corrs = new float[total];
IndTestCorrelationT ind = new IndTestCorrelationT(d,0.05);
IndTestMultinomialAJ indDisc = new IndTestMultinomialAJ(d,0.05);
for(int i = 0; i < items.size();i++)
{
Node one = nodes[i];
Expand All @@ -465,6 +566,16 @@ private static double[] getPartialCorr(TetradMatrix c, DataSet d, int x, int y)
int y = mapping.get(j);
corrs[index] = (float)(-1*c.get(x,y)/(c.get(x,x)*c.get(y,y)));
}
else if(d.isMixed())
{
indDisc.isIndependent(one,two);
if(ind.getPValue()>threshold && !withPrior[index])
corrs[index] = 0;
else if(ind.getPValue()>thresholdWP && withPrior[index])
corrs[index] = 0;
else
corrs[index] = (float) (1-ind.getPValue());
}
else {
ind.isIndependent(one,two);
if(ind.getPValue()>threshold && !withPrior[index])
Expand Down Expand Up @@ -562,7 +673,15 @@ public static ArrayList<Gene> computeAllIntensities(ArrayList<Gene> g1, double a
c = new CovarianceMatrix(data).getMatrix().ginverse();
}

IndTestCorrelationT ind = new IndTestCorrelationT(data,0.05);
IndependenceTest ind;
if(data.isMixed())
{
ind = new IndTestMultinomialAJ(data,0.05,true);
}
else
{
ind = new IndTestCorrelationT(data,0.05);
}
float [] corrs = new float[g1.size()];

/***For each gene in the gene list, compute correlation***/
Expand All @@ -586,7 +705,10 @@ else if(cont) //Use Correlation
corrs[i] = (float) Math.abs(StatUtils.correlation(temp[y], temp[data.getColumn(data.getVariable(g1.get(i).symbol))]));
}
else //Use Mutual Information
corrs[i] = (float) mixedMI(temp[data.getColumn(data.getVariable(i))],temp[y],numCats);
{
ind.isIndependent(data.getVariable(i),one);
corrs[i] = (float)(1-ind.getPValue());
}
}

/****Optional: Normalize the correlations to mean 0 Gaussian Dist***/
Expand Down Expand Up @@ -694,7 +816,16 @@ public static ArrayList<Gene> computeAllIntensities(ArrayList<Gene> g1, double a
c = new CovarianceMatrix(data).getMatrix();
c = c.ginverse();
}
IndTestCorrelationT ind = new IndTestCorrelationT(data,0.05);

IndependenceTest ind;
if(data.isMixed())
{
ind = new IndTestMultinomialAJ(data,0.05);
}
else
{
ind = new IndTestCorrelationT(data,0.05);
}
float [] corrs = new float[g1.size()];
for(int i = 0; i < g1.size();i++) {
Node two = data.getVariable(g1.get(i).symbol);
Expand All @@ -714,7 +845,10 @@ else if(ind.getPValue()>thresholdWP && withPriors[i])
corrs[i] = (float) Math.abs(StatUtils.correlation(temp[y], temp[data.getColumn(data.getVariable(g1.get(i).symbol))]));
}
else //Use Mutual Information
corrs[i] = (float) mixedMI(temp[data.getColumn(data.getVariable(i))],temp[y],numCats);
{
ind.isIndependent(one,two);
corrs[i] = (float) (1-ind.getPValue());
}
}
for (int i = 0; i < g1.size(); i++) {

Expand Down
46 changes: 41 additions & 5 deletions tetrad-lib/src/main/java/edu/pitt/csb/Pref_Div/PiPrefDiv4.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ public Map<Gene,List<Gene>> getLastCluster()
public double [] getAdjustP(){return adjustP;}
public void computePValue(boolean pvals){computeP = pvals;}
public void profile(){profiling = true;}
public DataSet getData(){return data;}

public double [] getNormTao()
{
if(normTao==null)
Expand Down Expand Up @@ -212,7 +214,11 @@ private double[] constrictRange(double[]init, DataSet data,boolean threshold)


float [] corrs;
if(parallel)
if(curr.isMixed())
{
corrs = Functions.computeAllAssociationsPar(temp,curr);
}
else if(parallel)
{
corrs = Functions.computeAllCorrelationsPar(temp,curr);
}else
Expand Down Expand Up @@ -326,12 +332,21 @@ public ArrayList<Gene> selectGenes(double radiiNP)


long time = System.nanoTime();


/***Add correlation to the target variable in the intensity value and fold change fields of each gene***/
ArrayList<Gene> meanGenes = Functions.computeAllIntensities(temp,1,data,target,partialCorrs,false,false,1);


/***Order of correlation is based upon the order of the genes in the array list (needs to be the same as the data order)***/
float [] meanDis = Functions.computeAllCorrelations(meanGenes,data,partialCorrs,false,false,1);
float [] meanDis;
if(data.isMixed())
{
meanDis = Functions.computeAllAssociationsPar(meanGenes,data);
}else
{
meanDis = Functions.computeAllCorrelations(meanGenes,data,partialCorrs,false,false,1);
}

if(profiling)
{
Expand Down Expand Up @@ -413,7 +428,15 @@ public ArrayList<Gene> selectGenes(double radiiNP, double radiiWP, String [] dFi


/***Compute correlation for each pair of genes***/
float [] meanDis = Functions.computeAllCorrelations(meanGenes,data,false,1,1,dPriorsWT);
float [] meanDis;
if(data.isMixed())
{
meanDis = Functions.computeAllAssociationsPar(meanGenes,data);
}else
{
meanDis = Functions.computeAllCorrelations(meanGenes,data,false,1,1,dPriorsWT);

}


/***Shuffle and sort the list of genes based upon intensity value***/
Expand Down Expand Up @@ -1009,7 +1032,15 @@ private int[] getCounts(int i, boolean needCorrs)
{
ArrayList<Gene> temp = createGenes(data,target,true);
DataSet currData = data.subsetRows(subsamples[k]);
float [] corrs = Functions.computeAllCorrelations(temp,currData,partialCorrs,false,false,1);
float [] corrs;
if(currData.isMixed())
{
corrs = Functions.computeAllAssociationsPar(temp,currData);
}
else
{
corrs = Functions.computeAllCorrelations(temp,currData,partialCorrs,false,false,1);
}
addGeneConnections(corrs,curr,initRadii[i]);

}
Expand Down Expand Up @@ -1494,7 +1525,12 @@ private void getPhi(float[] in, int numSubs)
float [] corrs;
long time = System.nanoTime();

if(parallel)

if(currData.isMixed())
{
corrs = Functions.computeAllAssociationsPar(temp,currData);
}
else if(parallel)
{
corrs = Functions.computeAllCorrelationsPar(temp,currData);
}
Expand Down
Loading

0 comments on commit e798e3e

Please sign in to comment.