给定一个包含 n 个整数的数组 nums,判断 nums 中是否存在三个元素 a,b,c ,使得 a + b + c = 0 ?找出所有满足条件且不重复的三元组。
注意:答案中不可以包含重复的三元组。
例如, 给定数组 nums = [-1, 0, 1, 2, -1, -4], 满足要求的三元组集合为: [ [-1, 0, 1], [-1, -1, 2] ]
分析:本题算法比较简单,主要考点是算法的时间复杂度
思路1:三次遍历,选取所有的[a,b,c]组合判断是否满足条件,时间复杂度为O(n^3)。
代码:
class Solution(object): def threeSum(self, nums): """ :type nums: List[int] :rtype: List[List[int]] """ result=[] for i in range(len(nums)): for j in range(i+1,len(nums)): for k in range (j+1,len(nums)): if nums[i]+nums[j]+nums[k]==0: temp=[nums[i],nums[j],nums[k]] temp.sort() if temp not in result: result.append(temp) # print result return result
结果:
思路2:以0为分界点,将数组分成大于0和小于0以及0三部分。这样可能出现的解答为[a,b,c],[a,0,b],[0,0,0]三大类,数组长度变短后遍历时所需时间也会变少,时间复杂度小于O(n^3)(具体是多少不会推导。。)
代码:
class Solution(object): def threeSum(self, nums): """ :type nums: List[int] :rtype: List[List[int]] """ pivots=[x for x in nums if x ==0] less =[x for x in nums if x <0] more =[x for x in nums if x >0] result=[] self.len_pivots_zero(result,less,more) if len(pivots)>=3: self.len_pivots_one(result,less,more) self.len_pivots_threemore(result,less,more) elif 3>len(pivots)>0: self.len_pivots_one(result,less,more) # end=time.time() # print end-starttime,result return result def len_pivots_zero(self,result,less,more): temp=[] for i in range(len(less)): for j in range(i+1,len(less)): k=0-less[i]-less[j] if k in more: temp=[less[i],less[j],k] temp.sort() if temp not in result: result.append(temp) for i in range(len(more)): for j in range(i+1,len(more)): k=0-more[i]-more[j] if k in less: temp=[more[i],more[j],k] temp.sort() if temp not in result: result.append(temp) def len_pivots_one(self,result,less,more): temp=[] for i in less: j=0-i if j in more: temp=[i,0,j] temp.sort() if temp not in result: result.append([i,0,j]) def len_pivots_threemore(self,result,less,more): result.append([0,0,0])
结果:
通过的测试用例明显增多,说明思路2算法耗时缩短,但是还是不满足题目要求。
思路3:先将原数组排序后,第一个点为a,然后下一个点为b,最后一个点为c,依次增大b减小c看是否满足要求。
代码:
class Solution(object): def threeSum(self, nums): """ :type nums: List[int] :rtype: List[List[int]] """ starttime=time.time() ans = [] nums.sort() for i in range(len(nums)-2): if i == 0 or nums[i] > nums[i-1]: left = i+1 right = len(nums)-1 while left < right: ident = nums[left] + nums[right] + nums[i] if ident == 0: ans.append([nums[i], nums[left], nums[right]]) left += 1; right -= 1 while left < right and nums[left] == nums[left-1]: # skip duplicates left += 1 while left < right and nums[right] == nums[right+1]: right -= 1 elif ident < 0: left += 1 else: right -= 1 end=time.time() print end-starttime,ans return ans
结果: