Monday, March 26, 2012

Multithreading in Java: an example

Hi there!

Here is a new small piece of code to help you understand how to multithread a simple task.

It is a multithreaded counter. Once again, it is deliberately simplistic, to help you understand how to create several threads.

To follow along, you'll need basic knowledge of core Java programming.

Here is the code - explanations at the end:
class Shared {
  double result=0.0;
  long startIndex, endIndex;
  static int nbThreads;
  Shared(long startIndex, long endIndex){
    this.startIndex=startIndex;
    this.endIndex=endIndex;
  }
  double getResult(){
    return result;
  }
  void computeResult(long index){
    //computing formula should be here
    result=result+1;
  }
}

// Create a new thread.
class NewThread implements Runnable {
  Shared s;
 
  NewThread(Shared s){
    //Init Shared data
    this.s=s;
  }
  
  public void run() {
    for (long i=s.startIndex;i<s.endIndex;i++){
        s.computeResult(i);
    }
  }
}


class DhaThreadsTest {
   
  public static void main(String args[]) {
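    int k;
    double result=0.0;
    long startTime, endTime, duration, totalSize;

    //Expecting two arguments: the size to count to and the number of threads
    if (args.length < 2) {
      System.out.println("Usage: java DhaThreadsTest <size> <nbThreads>");
      return;
    }
    totalSize=Long.parseLong(args[0]);
    Shared.nbThreads=Integer.parseInt(args[1]);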
   
    System.out.println("Size:"+ totalSize);
    System.out.println("NbThreads:"+ Shared.nbThreads);
     
    Thread threads[]  = new Thread[Shared.nbThreads]; 
    Shared shared[] = new Shared[Shared.nbThreads];
   
    //Starting timer
    startTime=System.nanoTime();
   
    for(k =0 ; k < Shared.nbThreads; k++) {
      if (k<Shared.nbThreads-1){
        shared[k] = new Shared(k*(totalSize/Shared.nbThreads),(k+1)*(totalSize/Shared.nbThreads));
        //System.out.println("Start index : " + (k*(totalSize/Shared.nbThreads)) +" End Index:" +((k+1)*(totalSize/Shared.nbThreads)));
      }
      else {
        shared[k] = new Shared(k*(totalSize/Shared.nbThreads), totalSize );
        //System.out.println("Start index : " + k*(totalSize/Shared.nbThreads) +" End Index:" + totalSize );
      }
      threads[k]  = new Thread(new NewThread(shared[k]));
      threads[k].start();   
    }

    for(k =0 ; k < Shared.nbThreads; k++) {
        //System.out.println("Waiting Thread: " + threads[k].getName()+" to complete");
        try{
          threads[k].join();
        }
        catch(InterruptedException e) {
        System.out.println("Thread: " + threads[k].getName()+" interrupted");
        }
    }

    for(k =0 ; k < Shared.nbThreads; k++) {
        //System.out.println("Result: " + result +" thread:" + k);
        result=result+shared[k].getResult();
    }

    //Stopping the timer
    endTime=System.nanoTime();
   
    //Calculate the duration
    duration= endTime-startTime;
   
    //Displaying the statistics
    System.out.println("Start at :              " + startTime);
    System.out.println("End at :                " + endTime);
    System.out.println("Duration (ms):          " + (float)duration/1000000);
    System.out.println("Duration/cell (ns/cell):" + (float)duration/(totalSize));

     
    System.out.println("Result:" + result );
    System.out.println("Exiting Main Thread");   
  }
}

Save this code in a plain text file named DhaThreadsTest.java.
To compile it: javac DhaThreadsTest.java
To run it: java DhaThreadsTest 5000000000 3
With this command line, the program counts up to 5000000000 using 3 counting threads.

Here is the output on my 4-core PC running Windows 7. Your timings will depend on your hardware and OS: you may get better results on a server and/or with a lighter OS, and if you have more cores you can use more threads.
G:\dev\java>java DhaThreadsTest 5000000000 3
Size:5000000000
NbThreads:3
Start at :              57512140612108
End at :                57516759505571
Duration (ms):          4618.8936
Duration/cell (ns/cell):0.92377865
Result:5.0E9
Exiting Main Thread
Explanations of the code:

Overview:
When run as: java DhaThreadsTest size nb_threads

This program creates nb_threads Shared objects, each holding one slice of the computation, and nb_threads Thread objects:
    Thread threads[]  = new Thread[Shared.nbThreads]; 
    Shared shared[] = new Shared[Shared.nbThreads];
Then it populates each of the Shared objects and starts the corresponding thread:
    for(k =0 ; k < Shared.nbThreads; k++) {
      if (k<Shared.nbThreads-1){
        shared[k] = new Shared(k*(totalSize/Shared.nbThreads),(k+1)*(totalSize/Shared.nbThreads));
        //System.out.println("Start index : " + (k*(totalSize/Shared.nbThreads)) +" End Index:" +((k+1)*(totalSize/Shared.nbThreads)));
      }
      else {
        shared[k] = new Shared(k*(totalSize/Shared.nbThreads), totalSize );
        //System.out.println("Start index : " + k*(totalSize/Shared.nbThreads) +" End Index:" + totalSize );
      }
      threads[k]  = new Thread(new NewThread(shared[k]));
      threads[k].start();   
    }

This splits the counting equally between the different threads.
Each thread k counts from k*(totalSize/Shared.nbThreads) (inclusive) to (k+1)*(totalSize/Shared.nbThreads) (exclusive), except for the last thread, which also takes the remainder and so counts from k*(totalSize/Shared.nbThreads) up to totalSize.
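The splitting rule above can be sketched as a small standalone helper. RangeSplit and split are hypothetical names introduced only for illustration; the arithmetic is the same as in main(): integer division gives each thread totalSize/nbThreads indices, and the last thread's end index is set to totalSize so that it also picks up the remainder.

```java
// Sketch of the range-splitting arithmetic from main(), pulled into a helper.
// RangeSplit and split() are hypothetical names, not part of the original program.
class RangeSplit {
    // Returns one {startIndex, endIndex} pair per thread; endIndex is exclusive.
    static long[][] split(long totalSize, int nbThreads) {
        long chunk = totalSize / nbThreads;   // integer division drops the remainder
        long[][] ranges = new long[nbThreads][2];
        for (int k = 0; k < nbThreads; k++) {
            ranges[k][0] = k * chunk;
            // The last thread runs up to totalSize, so it also takes the remainder.
            ranges[k][1] = (k < nbThreads - 1) ? (k + 1) * chunk : totalSize;
        }
        return ranges;
    }

    public static void main(String[] args) {
        for (long[] r : split(10, 3)) {
            System.out.println(r[0] + " -> " + r[1]);
        }
    }
}
```

For example, split(10, 3) gives the ranges 0 -> 3, 3 -> 6 and 6 -> 10: the last thread gets 4 indices instead of 3.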

Notice the threads[k].start() line, which explicitly starts each thread.
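It has to be start() and not run(): start() asks the JVM to create a new thread that then calls run(), whereas calling run() directly is an ordinary method call that executes the task in the current thread, so the slices would be counted one after another. A minimal sketch showing the difference, using a hypothetical class StartVsRun that records which thread executed the task:

```java
// Hypothetical demo class: start() runs the task in a new thread,
// while calling run() directly executes it in the caller's thread.
class StartVsRun {
    // Runs a task via start() or run() and returns the name of the thread that executed it.
    static String threadThatRan(boolean useStart) {
        final String[] seen = new String[1];
        Thread t = new Thread(() -> seen[0] = Thread.currentThread().getName(), "worker");
        if (useStart) {
            t.start();              // a new thread named "worker" executes the lambda
            try {
                t.join();           // wait for it, so seen[0] is visible here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // restore the interrupt flag
            }
        } else {
            t.run();                // plain method call: runs in the current thread
        }
        return seen[0];
    }

    public static void main(String[] args) {
        System.out.println("start(): " + threadThatRan(true));
        System.out.println("run():   " + threadThatRan(false));
    }
}
```

Run from the command line, the first line should report the thread named worker and the second should report main, the thread that called run().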

Then, the main thread waits until all the threads are done.
    for(k =0 ; k < Shared.nbThreads; k++) {
        //System.out.println("Waiting Thread: " + threads[k].getName()+" to complete");
        try{
          threads[k].join();
        }
        catch(InterruptedException e) {
        System.out.println("Thread: " + threads[k].getName()+" interrupted");
        }
    }
Note that we loop over the threads only once: we wait for the first thread to finish, then the second, then the third, and so on. A later thread may well finish before an earlier one. This is not a problem: join() returns immediately for a thread that has already terminated, so when the loop ends, every thread has finished and we can move on to the next step.
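To see that joining in index order is safe even when threads finish out of order, here is a small sketch. The class JoinOrderDemo and the sleep durations are made up for the demonstration: thread 0 sleeps longer than thread 1, so thread 1 will normally finish first, yet the join loop only returns once both are done.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

// Hypothetical demo: the first thread is made slower than the second,
// yet joining them in index order still waits for both.
class JoinOrderDemo {
    static List<Integer> runDemo() {
        List<Integer> finished = new CopyOnWriteArrayList<>(); // thread-safe list
        long[] sleepMs = {200, 10};          // thread 0 is deliberately the slowest
        Thread[] threads = new Thread[sleepMs.length];
        for (int k = 0; k < threads.length; k++) {
            final int id = k;
            threads[k] = new Thread(() -> {
                try {
                    Thread.sleep(sleepMs[id]);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                finished.add(id);            // record the completion order
            });
            threads[k].start();
        }
        for (Thread t : threads) {
            try {
                t.join();                    // returns at once if t has already finished
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return finished;                     // both ids are present once the loop ends
    }

    public static void main(String[] args) {
        // The order depends on scheduling; with these sleeps thread 1 usually finishes first.
        System.out.println("Completion order: " + runDemo());
    }
}
```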

Then, when all threads are done, we add up the results stored in each Shared object into the result variable.
    for(k =0 ; k < Shared.nbThreads; k++) {
        //System.out.println("Result: " + result +" thread:" + k);
        result=result+shared[k].getResult();
    }


Finally, we stop the timer and display the timing statistics.


Let's drill down into the program structure:
Each Shared object contains:
  • the starting and ending index for the computation for that instance of Shared
  • the result of the computation for that instance of Shared
  • the number of computing threads for the program (a static variable, since it is global to the program). Note that one thread is created for each instance of Shared.
  • a getResult method that returns the computation result
  • a computeResult method that computes the result for the specified index. Here we don't use the index, as we are just counting (adding 1 to result), but we could use a fancier formula, for instance result=result+1.0/((double)index*index) to calculate the sum of inverse squares (with indices starting at 1 rather than 0, to avoid dividing by zero).
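As a sketch of that idea, here is a hypothetical InverseSquareShared class (not part of the original program) whose computeResult() adds 1/index² instead of 1, with a small driver that splits the work across threads the same way main() does. Two details matter: indices start at 1, because index 0 would make the division produce Infinity, and the square is computed in double, because index*index in long arithmetic overflows once index exceeds about 3 billion, which the 5000000000 example would reach. The sum converges to π²/6 ≈ 1.6449.

```java
// Hypothetical variant of Shared whose computeResult() sums 1/index^2.
class InverseSquareShared {
    double result = 0.0;
    final long startIndex, endIndex;

    InverseSquareShared(long startIndex, long endIndex) {
        this.startIndex = startIndex;
        this.endIndex = endIndex;
    }

    void computeResult(long index) {
        double d = (double) index;           // square in double: no long overflow
        result = result + 1.0 / (d * d);
    }

    double getResult() {
        return result;
    }

    // Sums 1/i^2 for i in [1, n) using nbThreads threads, one instance per thread.
    static double sumInverseSquares(long n, int nbThreads) {
        InverseSquareShared[] parts = new InverseSquareShared[nbThreads];
        Thread[] threads = new Thread[nbThreads];
        long chunk = (n - 1) / nbThreads;    // indices 1 .. n-1 are split between threads
        for (int k = 0; k < nbThreads; k++) {
            long start = 1 + k * chunk;
            long end = (k < nbThreads - 1) ? start + chunk : n; // last thread takes the remainder
            InverseSquareShared s = new InverseSquareShared(start, end);
            parts[k] = s;
            threads[k] = new Thread(() -> {
                for (long i = s.startIndex; i < s.endIndex; i++) {
                    s.computeResult(i);
                }
            });
            threads[k].start();
        }
        double total = 0.0;
        for (int k = 0; k < nbThreads; k++) {
            try {
                threads[k].join();           // after join(), parts[k].result is final
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            total += parts[k].getResult();
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(sumInverseSquares(10_000_000L, 3));
        System.out.println(Math.PI * Math.PI / 6); // the limit of the series
    }
}
```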
The NewThread class implements Runnable so that its code can run in a child thread. This is one of the two ways to create a thread in Java; the other is to extend the Thread class.
Its main method is run(), which loops from the start index to the end index of the corresponding Shared instance and calls computeResult() for each index. The run() method is declared by the Runnable interface.
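The two ways can be compared side by side in a minimal sketch. CountingRunnable, CountingThread and TwoWaysDemo are hypothetical names: the first class follows the approach used in this post (a Runnable handed to a Thread), the second extends Thread and overrides run() directly.

```java
// Way 1: implement Runnable and hand the task to a Thread (as NewThread does).
class CountingRunnable implements Runnable {
    long count = 0;
    private final long n;
    CountingRunnable(long n) { this.n = n; }
    @Override
    public void run() {
        for (long i = 0; i < n; i++) count++;
    }
}

// Way 2: extend Thread and override run(); the object is itself the thread.
class CountingThread extends Thread {
    long count = 0;
    private final long n;
    CountingThread(long n) { this.n = n; }
    @Override
    public void run() {
        for (long i = 0; i < n; i++) count++;
    }
}

class TwoWaysDemo {
    static long[] countBothWays(long n) {
        CountingRunnable task = new CountingRunnable(n);
        Thread t1 = new Thread(task);                 // Runnable wrapped in a Thread
        CountingThread t2 = new CountingThread(n);    // subclass started directly
        t1.start();
        t2.start();
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return new long[] {task.count, t2.count};
    }

    public static void main(String[] args) {
        long[] c = countBothWays(1_000_000L);
        System.out.println("Runnable: " + c[0] + ", Thread subclass: " + c[1]);
    }
}
```

Implementing Runnable is usually preferred, since it keeps the task separate from the thread that runs it and leaves the class free to extend another class.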

Remarks:
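Note that each thread has its own Shared instance; despite the class name, no data is actually shared between the counting threads. So no thread can update a variable while another thread is reading or updating it, which means there is no concurrency issue. The main thread only reads the results after join() has returned, and the Java memory model guarantees that it then sees each thread's final values.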

Note also that we wait for all the threads to finish before consolidating the results, but this is not a big concern, as almost all of the time is spent computing rather than consolidating.

Performance improves roughly linearly as we increase the number of threads, up to about the number of available cores. Beyond that, extra threads bring little or no gain, and with too many threads the cost of switching between them becomes counterproductive.
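Since the useful number of threads is usually around the number of cores, a program can ask the JVM for that number with Runtime.getRuntime().availableProcessors(). The sketch below (hypothetical class ThreadCountChooser) falls back to one thread per available processor when no positive thread count is given; that fallback policy is an assumption, not something the original program does.

```java
// Hypothetical helper: pick a thread count, defaulting to the number of
// processors the JVM reports as available.
class ThreadCountChooser {
    // Uses the requested count if positive, otherwise one thread per available processor.
    static int chooseThreads(int requested) {
        int cores = Runtime.getRuntime().availableProcessors(); // always at least 1
        return requested > 0 ? requested : cores;
    }

    public static void main(String[] args) {
        int requested = args.length > 0 ? Integer.parseInt(args[0]) : 0;
        System.out.println("Available processors: " + Runtime.getRuntime().availableProcessors());
        System.out.println("Threads to use:       " + chooseThreads(requested));
    }
}
```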

Note that the size argument should be large enough that most of the time is spent computing rather than creating threads.

It is fun to watch your machine (no matter how powerful it is) crunching data and reaching 100% CPU utilisation (in the Windows Task Manager or Performance Monitor, or with vmstat / top on Unix) just doing such a simple task.

Concurrency problems can be quite difficult to solve, and we haven't met any here because our example is so simple. In more complex programs, it is hard to lock data efficiently: you have to avoid unnecessary locking (which makes the application slower) while still keeping the program thread-safe.
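To illustrate what changes when data really is shared, here is a sketch in which several threads increment the same object (SharedCounters is a hypothetical class). A plain result=result+1 on a single shared field could lose updates, because the read, the addition and the write are separate steps that another thread can interleave with. Two standard fixes are shown: a synchronized method, which takes a lock on every increment, and java.util.concurrent.atomic.AtomicLong, which avoids an explicit lock. Both are likely to be noticeably slower than the per-thread approach of the original program, because every increment now has to coordinate with the other threads.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical example: all threads update ONE object, so updates must be thread-safe.
class SharedCounters {
    private long lockedCount = 0;                      // guarded by the object's lock
    private final AtomicLong atomicCount = new AtomicLong();

    synchronized void incrementLocked() { lockedCount++; }    // one thread at a time
    void incrementAtomic() { atomicCount.incrementAndGet(); } // atomic, no explicit lock

    synchronized long getLocked() { return lockedCount; }
    long getAtomic() { return atomicCount.get(); }

    // nbThreads threads each increment both counters perThread times.
    static SharedCounters runThreads(int nbThreads, long perThread) {
        SharedCounters c = new SharedCounters();
        Thread[] threads = new Thread[nbThreads];
        for (int k = 0; k < nbThreads; k++) {
            threads[k] = new Thread(() -> {
                for (long i = 0; i < perThread; i++) {
                    c.incrementLocked();
                    c.incrementAtomic();
                }
            });
            threads[k].start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return c;
    }

    public static void main(String[] args) {
        SharedCounters c = runThreads(3, 1_000_000L);
        System.out.println(c.getLocked() + " " + c.getAtomic()); // prints 3000000 3000000
    }
}
```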

Finally, because concurrency problems can be hard to manage and debug, before you multithread a program it is recommended to first check whether you can optimize it algorithmically (avoid useless calculations, avoid unnecessary database access, etc.).

Have fun!
