Friday, April 27, 2012

Concurrency in Java

Concurrent programming is currently one of the most talked-about and most important features of Java.
The Fork/Join framework used in this post was released with Java 7, so the code only compiles with a Java 7 (or later) compiler.
Concurrent programming here means creating threads or a thread pool and distributing the work among the threads, so that the job finishes sooner by keeping every core busy.
This matters for the problem below: a brute-force search on a single thread would take a very long time, while the divide-and-conquer version running on a Fork/Join pool produces the output in about 1 second.
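
To make the idea of distributing work over a pool concrete, here is a minimal sketch that sums an array with Fork/Join. The class name SumTask, the helper parallelSum and the threshold of 1,000 are my own assumptions for illustration; they are not part of the program in this post.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical example: sums an int array by splitting it into halves.
class SumTask extends RecursiveTask<Long> {
    private static final long serialVersionUID = 1L;
    private static final int THRESHOLD = 1_000; // assumed cut-off, not tuned
    private final int[] data;
    private final int from; // inclusive
    private final int to;   // exclusive

    SumTask(int[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        // Small ranges are summed directly on the current thread.
        if (to - from <= THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data[i];
            }
            return sum;
        }
        // Larger ranges are split: fork the left half, compute the right half here.
        int mid = (from + to) >>> 1;
        SumTask left = new SumTask(data, from, mid);
        SumTask right = new SumTask(data, mid, to);
        left.fork();
        long rightSum = right.compute();
        return left.join() + rightSum;
    }

    static long parallelSum(int[] data) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(data, 0, data.length));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] data = new int[100_000];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }
        System.out.println(parallelSum(data)); // prints 5000050000
    }
}
```

Forking one half and computing the other half on the current thread keeps that thread busy instead of leaving it waiting on two joins.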

PROBLEM STATEMENT:

Randomly generate one million (x, y) points, where 0.0 <= x < 1.0 and 0.0 <= y < 1.0, and x and y are of type double. Then find and print out the two closest points, along with the distance between them.
Do this as quickly as possible, and print out the time required in seconds, as a floating point number.
Use ForkJoin. If you have a dual-core or quad-core computer, this should make your program about 2x or 4x faster. If you don't, using ForkJoin will slow things down slightly--but use it anyway!
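
Since the statement asks for the time in seconds as a floating point number, here is one possible way to measure it. This is only a sketch: the class TimingSketch and the helper toSeconds are names I made up, and the loop is just a stand-in workload.

```java
// Hypothetical example: measure elapsed time and report it in fractional seconds.
class TimingSketch {
    // Dividing by a double literal keeps the fractional part of the result.
    static double toSeconds(long elapsedNanos) {
        return elapsedNanos / 1_000_000_000.0;
    }

    public static void main(String[] args) {
        // The number of cores gives a rough upper bound on the speed-up to expect.
        System.out.println("Available cores: " + Runtime.getRuntime().availableProcessors());

        long start = System.nanoTime();
        double sink = 0;
        for (int i = 1; i <= 1_000_000; i++) {
            sink += Math.sqrt(i); // stand-in for the real work
        }
        double seconds = toSeconds(System.nanoTime() - start);
        System.out.println("Result " + sink + " computed in " + seconds + " seconds");
    }
}
```

Dividing a long by the integer 1000 would drop everything after the decimal point, which is why the divisor here is a double.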

My code consists of 3 parts. It is always advisable to split the work into separate classes and then combine them, so that we don't end up with a single file running to thousands of lines; it also makes debugging easier.

PointSet class:
This class represents a single point with 2 coordinates, x and y.
It lets us create and store the one million point objects the problem needs.

package jointFork;

import java.util.Comparator;

/**
 * @author Arpit Jain
 * @version April 17,2012
 * This class creates the objects for the 2 coordinates of a point.
 *
 */
public class PointSet implements Comparable<PointSet>{

double x;
double y;

public PointSet(double x, double y) {
this.x = x;
this.y = y;
}

public double getX() {
return this.x;
}

public double getY() {
return this.y;
}

@Override
public String toString(){
String output = "x = " + this.x + ", y = " + this.y; //$NON-NLS-1$ //$NON-NLS-2$
return output;
}

@Override
public int compareTo(PointSet point1) {
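// Double.compare avoids the truncation of (int)(this.x - point1.x), which is always 0 for values in [0, 1)
return Double.compare(this.x, point1.x);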
}
/**
* Creates duplicate of the same PointSet
*/
public PointSet clone(){
PointSet cloneSet = new PointSet(x, y);
return cloneSet;
}
/**
* This function is used within Array Sort.
* It sorts the array according to only x points.
*/
public static final Comparator<PointSet> PointSetComparatorX = new Comparator<PointSet>() {
@Override
public int compare(PointSet point1, PointSet point2) {
if (point1.x < point2.x) {
return -1;
}
if (point1.x > point2.x) {
return 1;
}
return 0;
}
};
/**
* This function is used within Array Sort.
* It sorts the array according to only y points.
*/
public static final Comparator<PointSet> PointSetComparatorY = new Comparator<PointSet>() {
@Override
public int compare(PointSet point1, PointSet point2) {
if (point1.y < point2.y) {
return -1;
}
if (point1.y > point2.y) {
return 1;
}
return 0;
}
};
}
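
A small aside on comparing double coordinates in the range [0, 1): a comparison that casts the difference to int always truncates to 0, so every pair of points looks equal. The sketch below contrasts that with Double.compare; the class CompareSketch and its two helper methods are hypothetical names used only for this illustration.

```java
// Hypothetical example: why casting a double difference to int is a poor comparison.
class CompareSketch {
    // Truncates toward zero, so any difference strictly between -1 and 1 becomes 0.
    static int truncatingCompare(double a, double b) {
        return (int) (a - b);
    }

    // Returns a negative, zero or positive value according to the real ordering.
    static int safeCompare(double a, double b) {
        return Double.compare(a, b);
    }

    public static void main(String[] args) {
        System.out.println(truncatingCompare(0.2, 0.9)); // prints 0, although 0.2 < 0.9
        System.out.println(safeCompare(0.2, 0.9) < 0);   // prints true
    }
}
```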

ForkJoinPoints class:
This class only creates the fork/join pool and sets the starting parameters.
Its only job is to read the seed from the user, generate the points, measure the time and print the distance between the 2 closest points it receives.


package jointFork;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;

/**
 *
 * @author Arpit Jain
 * @version April 17,2012
 *
 */
public class ForkJoinPoints {

public static ForkJoinPool forkPool = new ForkJoinPool();
private static Random rand;
static PointSet[] xSorted;
public static BufferedReader br;
static final int numberOfPoints = 1000000;

/**
* Takes the seed for the random number generator as user input.
* Sorts all the generated points (xSorted) with respect to their x coordinates.
* Displays the time taken (in seconds) to find the closest pair of points.
* @param args
*/
public static void main(String[] args) {

float time;

System.out.println("Enter the Seed value:");

br = new BufferedReader(new InputStreamReader(System.in));
String inputEntry = null;
int intValue = 0;
try {
inputEntry = br.readLine();
// Parse inside the try block so the NumberFormatException handler below can actually catch bad input
intValue = Integer.parseInt(inputEntry);
}
catch(IOException ioe) {
System.out.println("IO error trying to read your entry!"); //$NON-NLS-1$
System.exit(1);
}
catch(NumberFormatException i) {
System.out.println("invalid entry "); //$NON-NLS-1$
System.exit(1);
}
rand =  new Random(intValue);
int numbers = 0;
xSorted = new PointSet[numberOfPoints];


while(numbers < numberOfPoints) {
double Xvalue = rand.nextDouble();
double Yvalue = rand.nextDouble();
xSorted [numbers] = new PointSet(Xvalue,Yvalue);
numbers++;
}

long start_time = System.currentTimeMillis();

Arrays.sort(xSorted,PointSet.PointSetComparatorX);


PointSet[] pairs;

pairs = forkPool.invoke(new SplitWork(xSorted,0, numberOfPoints - 1));
// print value of closest xSorted
System.out.println("The Closest Pairs are : " );
System.out.println(pairs[0] + "\n" + pairs[1]);
double dist1 = distance(pairs[0],pairs[1]);
System.out.println("Distance between the 2 closest points is : " + dist1);
// Divide by a float so the fractional seconds are kept (long / 1000 would truncate them)
time = (System.currentTimeMillis() - start_time) / 1000.0f;
System.out.println("Time taken = " + time + " seconds");
}

public static double distance(PointSet p1, PointSet p2) {
double powerX = Math.pow(p1.x - p2.x, 2);
double powerY = Math.pow(p1.y - p2.y, 2);
double finalDistance = Math.sqrt(powerX + powerY);
return finalDistance;
}
}
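
The reason for asking the user for a seed is reproducibility: the same seed gives the same sequence of coordinates, so two runs can be compared. A minimal sketch of that property follows; SeedSketch and coordinates are hypothetical names, and the array of plain doubles is a simplification of the point objects used above.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical example: the same seed reproduces the same pseudo-random coordinates.
class SeedSketch {
    // Generates 'count' values in [0.0, 1.0) from the given seed.
    static double[] coordinates(long seed, int count) {
        Random rand = new Random(seed);
        double[] values = new double[count];
        for (int i = 0; i < count; i++) {
            values[i] = rand.nextDouble();
        }
        return values;
    }

    public static void main(String[] args) {
        double[] first = coordinates(42L, 5);
        double[] second = coordinates(42L, 5);
        System.out.println(Arrays.equals(first, second)); // prints true
    }
}
```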


SplitWork class:
This class does all the calculation. It works on the array of PointSet objects and uses a divide-and-conquer closest-pair algorithm to find the two closest points. Brute force would have to compare roughly 5 × 10^11 pairs for a million points, which would take far too long, hence we don't make use of brute force.

package jointFork;

import java.util.Arrays;
import java.util.concurrent.RecursiveTask;

/**
 * @author Arpit Jain
 * @version April 17,2012
 * The class extends RecursiveTask, i.e. each task is
 * given its own range of work through the compute function, and
 * larger ranges are handled by recursively creating further tasks.
 * 
 */
class SplitWork extends RecursiveTask<PointSet[]> {
private static final long serialVersionUID = 1L;
int startIndex;
int endIndex;
private static PointSet[] xSorted;
static final int thresholdValue = 5000;
/**
* This constructor stores the array containing the
* points and the initial start and end index when the task is
* created for the fork join pool.
* @param p - Array of objects(x & y coordinates)
* @param startIndex - initial index of the size of array i.e 0
* @param endIndex - last index of the array
*/

SplitWork(PointSet[]p,int startIndex, int endIndex) {
this.startIndex = startIndex;
this.endIndex = endIndex;
SplitWork.setxSorted(p);
}
/**
* This constructor is used to call recursively
* @param startIndex- current startIndex
* @param endIndex - current End index
*/
SplitWork(int startIndex, int endIndex) {
this.startIndex = startIndex;
this.endIndex = endIndex;
}
/**
*Compute function provides each thread with its set of work 
*/
public PointSet[] compute() {
return closestPoints(startIndex,endIndex);
}

/**
*  Calculating the distance between 2 points
* @param p1 -First point
* @param p2 - second point
* @return -distance between them
*/
public static double distance(PointSet p1, PointSet p2) {
double powerX = Math.pow(p1.x - p2.x, 2);
double powerY = Math.pow(p1.y - p2.y, 2);
double finalDistance = Math.sqrt(powerX + powerY);
return finalDistance;
}
/**
* This function computes the closest pairs of xSorted by dividing 
* the array of xSorted recursively
* @param startIndex - starting index of the array of xSorted
* @param endIndex - last index of the array of xSorted
* @return - the closest pairs of xSorted
*/
static PointSet[] closestPoints(final int startIndex, final int endIndex) {
PointSet[]  pairs = new PointSet[2];
if(endIndex - startIndex == 0) {
return null;
}

else if(endIndex - startIndex == 1) {
pairs[0] = getxSorted()[startIndex];
pairs[1] = getxSorted()[endIndex];
return pairs;
}

// else if(endIndex- startIndex+1 == 3 ) {
//
// double minimumDistance= -1;
// for(int i = 0; i<3;i++) {
// for(int j = i+1;j<3;j++) {
// double dist = distance(xSorted[i],xSorted[j]);
// if(i == 0 && j == i+1) 
// minimumDistance = dist;
//
// if(minimumDistance > dist) {
// minimumDistance = dist;
// pairs[0] = xSorted[i];
// pairs[1] = xSorted[j];
// }
// }
// }
// return pairs;
// }
else if(endIndex - startIndex == 2) {
int number = getClosestPoints(getxSorted()[startIndex],getxSorted()[startIndex+1],getxSorted()[endIndex]);
if(number == 0) {
return null;
}
else if(number == 1) {
pairs[0] = getxSorted()[startIndex];
pairs[1] = getxSorted()[startIndex + 1];
}
else if(number == 2) {
pairs[0] = getxSorted()[startIndex+1];
pairs[1] = getxSorted()[endIndex];
}
else if(number == 3) {
pairs[0] = getxSorted()[startIndex];
pairs[1] = getxSorted()[endIndex];
}
return pairs;
}

final int midIndex = (startIndex + endIndex) / 2;
int leftValue = midIndex - startIndex +1;
int rightValue = endIndex - midIndex;
PointSet[] leftSet = null;
PointSet[] rightSet = null;
SplitWork leftSide = null, rightSide = null;

if (leftValue < thresholdValue) {
leftSet = closestPoints(startIndex, midIndex);

} else {
leftSide = new SplitWork(startIndex, midIndex);
leftSide.fork();
}

if (rightValue < thresholdValue) {
// No need to allocate an array first; closestPoints returns its own
rightSet = closestPoints(midIndex + 1, endIndex);


} else {
rightSide = new SplitWork(midIndex + 1, endIndex);
rightSide.fork();
}

if (leftSide != null) {
leftSet = leftSide.join();
}
if (rightSide != null) {
rightSet = rightSide.join();
}

//finding the minimum distance on left and right
double minLeft = distance(leftSet[0], leftSet[1]);
double minRight = distance(rightSet[0], rightSet[1]);
double minRL;
minRL = Math.min(minLeft, minRight);

if (minRL == minLeft) {
pairs[0] = leftSet[0];
pairs[1] = leftSet[1];
} else {
pairs[0] = rightSet[0];
pairs[1] = rightSet[1];

}

double minDist= minRL;

int leftMiddle= midIndex;
int rightMiddle = midIndex;
double middleX = getxSorted()[midIndex].x;
double middleLeftX = middleX - minRL;
double middleRightX = middleX + minRL;

while (leftMiddle> startIndex && getxSorted()[leftMiddle- 1].x > middleLeftX) {
leftMiddle--;
}

while (rightMiddle < endIndex && getxSorted()[rightMiddle + 1].x < middleRightX) {
rightMiddle++;
}

/*
*Array of elements which might have the shortest 
*distance but are divided by the midpoint.
*
*/
PointSet []Ystrip = new PointSet[rightMiddle - leftMiddle+ 1];

for (int index = 0; index < Ystrip.length; index++) {
Ystrip[index] = getxSorted()[index + leftMiddle].clone();
}

Arrays.sort(Ystrip,PointSet.PointSetComparatorY);


// Compare each point in Ystrip only with the following points whose y distance is below minRL
for (int i = 0; i < Ystrip.length - 1; i++) {
for (int j = i+1; j < Ystrip.length && (Ystrip[j].getY() -  Ystrip[i].getY() < minRL); j++) {
double d = distance(Ystrip[i],Ystrip[j]);
if (d < minDist) {
minDist= d;
pairs[0] = Ystrip[i];
pairs[1] = Ystrip[j];
}
}
}

return pairs;
}
/**
* For the array of size 3 we calculate the smallest distance by using 
* brute force.
* @param p1- first point
* @param p2- second point
* @param p3- third point
* @return 1 if p1 and p2 are closest, 2 if p2 and p3 are closest, 3 if p1 and p3 are closest
*/
public static  int getClosestPoints(PointSet p1, PointSet p2, PointSet p3) {
       double distance1 = distance(p1,p2);
       double distance2  = distance(p2,p3);
       double distance3 = distance(p1,p3);
       if(distance1 <= distance2 && distance1 <= distance3) {
        return 1;
       }
       else if(distance2 <= distance1 && distance2 <= distance3) {
        return 2;
       }
       else if(distance3 <= distance1 && distance3<=distance2) {
        return 3;
       }
       return 0;
   }

public static PointSet[] getxSorted() {
return xSorted;
}

public static void setxSorted(PointSet[] xSorted) {
SplitWork.xSorted = xSorted;
}
}
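
To gain some confidence in a divide-and-conquer result, one option is to compare it against a brute-force pass on a small input, where checking every pair is still cheap. The sketch below is only such a reference check: BruteForceSketch and closestDistance are hypothetical names, and it uses plain coordinate arrays instead of PointSet so that it stands on its own.

```java
import java.util.Random;

// Hypothetical example: O(n^2) reference implementation for small inputs only.
class BruteForceSketch {
    // Returns the smallest distance between any two of the points (xs[i], ys[i]).
    static double closestDistance(double[] xs, double[] ys) {
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < xs.length - 1; i++) {
            for (int j = i + 1; j < xs.length; j++) {
                double d = Math.hypot(xs[i] - xs[j], ys[i] - ys[j]);
                if (d < best) {
                    best = d;
                }
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // 2,000 points means about 2 million pairs, which is quick to check.
        int n = 2_000;
        Random rand = new Random(42L);
        double[] xs = new double[n];
        double[] ys = new double[n];
        for (int i = 0; i < n; i++) {
            xs[i] = rand.nextDouble();
            ys[i] = rand.nextDouble();
        }
        System.out.println("Brute-force closest distance: " + closestDistance(xs, ys));
    }
}
```

Running both versions with the same seed on a few thousand points and comparing the printed distances is a cheap sanity check before moving to a million points.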

