diff --git a/divide_and_conquer/closest_pair_of_points.py b/divide_and_conquer/closest_pair_of_points.py
index cc5be428..ee06d270 100644
--- a/divide_and_conquer/closest_pair_of_points.py
+++ b/divide_and_conquer/closest_pair_of_points.py
@@ -1,27 +1,27 @@
 """
-The algorithm finds distance btw closest pair of points in the given n points.
+The algorithm finds distance between closest pair of points 
+in the given n points.
 Approach used -> Divide and conquer 
-The points are sorted based on Xco-ords 
-& by applying divide and conquer approach, 
+The points are sorted based on Xco-ords and 
+then based on Yco-ords separately.
+And by applying divide and conquer approach, 
 minimum distance is obtained recursively.
 
->> closest points lie on different sides of partition
+>> Closest points can lie on different sides of partition.
 This case handled by forming a strip of points 
 whose Xco-ords distance is less than closest_pair_dis
-from mid-point's Xco-ords.
+from mid-point's Xco-ords. Points sorted based on Yco-ords 
+are used in this step to reduce sorting time.
 Closest pair distance is found in the strip of points. (closest_in_strip)
 
 min(closest_pair_dis, closest_in_strip) would be the final answer.
  
-Time complexity: O(n * (logn)^2)
+Time complexity: O(n * log n)
 """
 
 
-import math 
-
-
 def euclidean_distance_sqr(point1, point2):
-    return pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2)
+    return (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
 
 
 def column_based_sort(array, column = 0):
@@ -66,7 +66,7 @@ def dis_between_closest_in_strip(points, points_counts, min_dis = float("inf")):
     return min_dis
 
 
-def closest_pair_of_points_sqr(points, points_counts):
+def closest_pair_of_points_sqr(points_sorted_on_x, points_sorted_on_y, points_counts):
     """ divide and conquer approach
 
     Parameters : 
@@ -79,12 +79,16 @@ def closest_pair_of_points_sqr(points, points_counts):
 
     # base case
     if points_counts <= 3:
-        return dis_between_closest_pair(points, points_counts)
+        return dis_between_closest_pair(points_sorted_on_x, points_counts)
     
     # recursion
     mid = points_counts//2
-    closest_in_left = closest_pair_of_points(points[:mid], mid)
-    closest_in_right = closest_pair_of_points(points[mid:], points_counts - mid)
+    closest_in_left = closest_pair_of_points_sqr(points_sorted_on_x, 
+                                                 points_sorted_on_y[:mid], 
+                                                 mid)
+    closest_in_right = closest_pair_of_points_sqr(points_sorted_on_y, 
+                                                  points_sorted_on_y[mid:], 
+                                                  points_counts - mid)
     closest_pair_dis = min(closest_in_left, closest_in_right)
     
     """ cross_strip contains the points, whose Xcoords are at a 
@@ -92,22 +96,25 @@ def closest_pair_of_points_sqr(points, points_counts):
     """
 
     cross_strip = []
-    for point in points:
-        if abs(point[0] - points[mid][0]) < closest_pair_dis:
+    for point in points_sorted_on_x:
+        if abs(point[0] - points_sorted_on_x[mid][0]) < closest_pair_dis:
             cross_strip.append(point)
 
-    cross_strip = column_based_sort(cross_strip, 1)
     closest_in_strip = dis_between_closest_in_strip(cross_strip, 
                      len(cross_strip), closest_pair_dis)
     return min(closest_pair_dis, closest_in_strip)
 
     
 def closest_pair_of_points(points, points_counts):
-    return math.sqrt(closest_pair_of_points_sqr(points, points_counts))
+    points_sorted_on_x = column_based_sort(points, column = 0)
+    points_sorted_on_y = column_based_sort(points, column = 1)
+    return (closest_pair_of_points_sqr(points_sorted_on_x, 
+                                       points_sorted_on_y, 
+                                       points_counts)) ** 0.5
 
 
-points = [(2, 3), (12, 30), (40, 50), (5, 1), (12, 10), (0, 2), (5, 6), (1, 2)]
-points = column_based_sort(points)
-print("Distance:", closest_pair_of_points(points, len(points)))
+if __name__ == "__main__":
+    points = [(2, 3), (12, 30), (40, 50), (5, 1), (12, 10), (3, 4)] 
+    print("Distance:", closest_pair_of_points(points, len(points)))