数据结构与算法:回溯(下):子集相关力扣题(78.子集、90.子集Ⅱ、491.非递减子序列)、排列相关力扣题(46.全排列、47.全排列Ⅱ)、棋盘问题(51.N皇后、37.解数独)

1.子集相关力扣题

78.子集

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
    	# 注意一开始要放入空子集
        ans = [[]]
        def backtracking(cur:int, path:List[int]) -> None:
            if len(path) == len(nums):
                return
            for i in range(cur, len(nums)):
                path.append(nums[i])
                # 区别只是在每处理一个节点的时候,就需要放入结果集
                ans.append(path.copy())
                backtracking(i+1, path)
                path.pop()
            return
        backtracking(0,[])
        return ans

效率:0ms,击败100.00%

90.子集Ⅱ

class Solution:
    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:
        ans = [[]]
        nums.sort()
        def backtracking(cur:int, path:List[int]) -> None:
            if len(path) == len(nums):
                return
            for i in range(cur,len(nums)):
                if i>cur and nums[i]==nums[i-1]:
                    # 跳过重复元素
                    continue
                path.append(nums[i])
                ans.append(path.copy())
                backtracking(i+1, path)
                path.pop()
            return
        backtracking(0,[])
        return ans

效率:0ms,击败100.00%

491.非递减子序列

给你一个整数数组 nums ,找出并返回所有该数组中不同的递增子序列,递增子序列中 至少有两个元素 。你可以按 任意顺序 返回答案。

数组中可能含有重复元素,如出现两个整数相等,也视作递增序列的一种特殊情况。

示例 1:

输入:nums = [4,6,7,7]
输出:[[4,6],[4,6,7],[4,6,7,7],[4,7],[4,7,7],[6,7],[6,7,7],[7,7]]

示例 2:

输入:nums = [4,4,3,2,1]
输出:[[4,4]]

提示:

1 <= nums.length <= 15
-100 <= nums[i] <= 100

初版代码

一开始我的代码如下:

class Solution:
    def findSubsequences(self, nums: List[int]) -> List[List[int]]:
        if len(nums)<2:
            return []
        ans = []
        def backtracking(cur:int, path:List[int]) -> None:
        	"""
        	因为还是在一个集合中挑选数。
        	所以需要用cur来记录该集合中已经遍历到哪了,防止重复
        	"""
            if cur==len(nums):
                return
            for i in range(cur, len(nums)):
			"""
	 		因为数组中可能含有重复元素,而答案要求去重。
	 		所以数组遍历时,遇到相同的数就跳过。
	 		但注意,因为答案要求的子序列是不能改变其在数组的先后顺序的。
	 		所以在这我们不应在开始给数组排序。
			"""	 		
                if i>cur and nums[i]==nums[i-1]:
                    continue
                if path and nums[i]<path[-1]:
                # 判断是否是非递减
                    continue
                path.append(nums[i])
                ans.append(path.copy())
                backtracking(i+1,path)
                path.pop()
            return
        backtracking(0, [])
        # 最后过滤长度小于2的元素
        filtered_ans = [path for path in ans if len(path)  >=2]
        return filtered_ans

这个代码通过了47 / 58。当遇到序列:[1,2,1,1] 时,其输出是:

[[1,2],[1,1],[1,1,1],[1,1]]

但实际上正确输出应该是

[[1,2],[1,1],[1,1,1]]

这很明显是去重出了问题。那对于这种abaa式的例子,为什么去重出了错误?
因为答案允许的是

a1 b
a1 a2
a1 a2 a3

但代码生成的是

a1 b
a1 a2
a1 a2 a3
a2 a3

其中,a1 a2a2 a3会被认定为重复。如果数组排序了,那么a1a2a3会因为相邻在一起,从而被

if i>cur and nums[i]==nums[i-1]:
      continue

去重。

但是因为数组没有排序,所以此时的去重只针对原本在数组中就相邻的重复元素,譬如a2a3,但因为a1与他们不相邻,所以没办法去掉。

所以,我们应该考虑换一种去重方式,关注到题目说-100 <= nums[i] <= 100,这实际上是个很小的范围,那么我们就可以使用专门的集合来记录已使用的数字

其次,我们要考虑的是这个集合应该定义在那,因为``

最后就是关于长度小于2的元素应该被过滤,实际上可以放在终止条件那,只有当path的长度≥2时,我们才收集结果。

终版代码

class Solution:
    def findSubsequences(self, nums: List[int]) -> List[List[int]]:
        if len(nums)<2:
            return []
        ans = []
        def backtracking(cur:int, path:List[int]) -> None:
            if len(path) >= 2:
                ans.append(path.copy())
            used = set() # 对每一层进行去重,而不是整体去重,否则收集不到如题目示例1中的[7,7]的答案
            for i in range(cur, len(nums)):
                if path and nums[i]<path[-1]:
                    continue
                if nums[i] in used:
                    continue
                used.add(nums[i])
                path.append(nums[i])
                backtracking(i+1,path)
                path.pop()
            return
        backtracking(0, [])
        return ans

效率:19ms,击败86.55%

这里还有一个非常小的优化点,就是我们要关注两个剪枝的顺序。

如果你把if nums[i] in used放在前面,其实效率差不多是20ms, 击败50%左右的程度。

这是因为if path and nums[i]<path[-1]分别是内置长度的判断,和直接索引判断两个数,是O(1)的复杂度。能筛选nums[i]是不是一个值得继续遍历的值。

如果是值得遍历的值,再去看它是否已经用过。所以将它放在前面效率会更快一点。

2.排列相关力扣题

46.全排列

因为这道题目给的数组序列不会有重复的数,所以我们直接定义一个used = set(),将用过的数字放进去,就能避免一个数字被重复使用。

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        ans = []
        def backtracking(path:List[int], used:set) -> None:
            if len(path) == len(nums):
                ans.append(path.copy())
                return
            for i in range(len(nums)):
                if nums[i] in used:
                    continue
                path.append(nums[i])
                used.add(nums[i])
                backtracking(path, used)
                path.pop()
                used.remove(nums[i])
            return
        backtracking([], set())
        return ans

效率:3ms,击败46.01%

47.全排列Ⅱ

因为这道题的数组可能有重复的数,所以我们不能再单纯的定义一个used=set()了,因为这样会让我们漏掉答案。譬如对于数组nums:[1,1,2],我们如果放入了第一个数字1,那么第二个数字1就放不进去了,但是答案是有[1,1,2]的。所以我们应该将used定义为和数组等长的布尔数组。这样以下标来区分同一个数是否被重复使用。

from typing import List

class Solution:
    def permuteUnique(self, nums: List[int]) -> List[List[int]]:
        ans = []
        nums.sort()  # 对数组进行排序,方便去重
        used = [False] * len(nums)
        def backtracking(path: List[int], used) -> None:
            if len(path) == len(nums):
                ans.append(path.copy())
                return
            
            for i in range(len(nums)):
                # 当前下标的数已经被使用过了。这样会造成一个数字使用多次。
                if used[i]:
                    continue
                # 虽然当前下标的数没有被使用过,但是其前一个数和当前数的值一样,而且前一个数被使用过了
                # 这样会造成原数组重复的数组成重复的排列序列。
                if i>0 and used[i-1] and nums[i]==nums[i-1]:
                    continue
                path.append(nums[i])
                used[i] = True
                backtracking(path, used)
                path.pop()
                used[i] = False
            return
        
        backtracking([], used)
        return ans

效率:11ms,击败27.40%

同样的,其实46.全排列也可以将used设置成和数组等长的布尔数组的。

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        ans = []
        used = [False] * len(nums)
        def backtracking(path:List[int], used) -> None:
            if len(path) == len(nums):
                ans.append(path.copy())
                return
            for i in range(len(nums)):
                if used[i]:
                    continue
                path.append(nums[i])
                used[i] = True
                backtracking(path, used)
                path.pop()
                used[i] = False
            return
        backtracking([], used)
        return ans

这样效率还会更快,0ms,击败100.00%

3.棋盘问题

51.N皇后(hard)

以下几点是重难点。

  1. 棋盘是二维的,也就是我们要遍历的集合是二维的。
  2. 判断皇后放置的位置是否合法。
  3. 在遍历第i行的时候,如何记得第i-2行及之前的状态,从而放置合法的位置(其实就是第二点)。

先来考虑第一个点,因为遍历的集合是二维的,相应的,我们的path也应该变成二维的。

其次是位置是否合法的问题,在这里有多个解决方案:

  • 我们可以存储三个set,分别存储一个位置的列、45°角,135°角是否重复的情况。
  • 或者我们直接遍历path去判断。

代码1:path为嵌套的字符列表,且通过遍历棋盘去判断位置合法

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans = []
        # 初始化棋盘为n*n的字符列表以便修改
        path = [['.' for _ in range(n)] for _ in range(n)]
        
        def isValid(row, col):
            # 检查同一列是否有皇后
            for i in range(row):
                if path[i][col] == 'Q':
                    return False
            # 检查主对角线(左上)
            i, j = row - 1, col - 1
            while i >= 0 and j >= 0:
                if path[i][j] == 'Q':
                    return False
                i -= 1
                j -= 1
            # 检查副对角线(右上)
            i, j = row - 1, col + 1
            while i >= 0 and j < n:
                if path[i][j] == 'Q':
                    return False
                i -= 1
                j += 1
            return True
        
        def backtrack(cur_row):
            if cur_row == n:
                # 将字符列表转换为字符串形式
                ans.append([''.join(row) for row in path])
                return
            for col in range(n):
                if isValid(cur_row, col):
                    path[cur_row][col] = 'Q'  # 放置皇后
                    backtrack(cur_row + 1)    # 递归处理下一行
                    path[cur_row][col] = '.'  # 回溯,撤销皇后
        
        backtrack(0)
        return ans

效率:31ms,击败34.20%

可以看到,判断位置是否合法的代码非常冗长。


除此之外,path的初始化也非常重要。

注意到在这我们初始化时没有使用 path = ['.' * n for _ in range(n)],这是因为这样初始化后,path是一个字符串数组。例如,n=4时,如下:

path = ['. . . .', '. . . .', '. . . .', '. . . .']

由于字符串是不可变的,我们不能直接通过索引修改某个字符,那么如下代码会报错:

path[row][col] = 'Q'  # 这会报错!

不过在添加结果时要将内层字符列表转化为字符串:

ans.append([''.join(row) for row in path])

代码2:path为字符串列表,且通过遍历棋盘去判断位置是否合法

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans = []
        # 初始化棋盘为字符串列表
        path = ['.' * n for _ in range(n)]
        
        def isValid(cur: int, col: int) -> bool:
            # 检查列是否有冲突
            for i in range(cur):
                if path[i][col] == 'Q':
                    return False
            
            # 检查45°对角线是否有冲突
            i, j = cur - 1, col - 1
            while i >= 0 and j >= 0:
                if path[i][j] == 'Q':
                    return False
                i -= 1
                j -= 1
            
            # 检查135°对角线是否有冲突
            i, j = cur - 1, col + 1
            while i >= 0 and j < n:
                if path[i][j] == 'Q':
                    return False
                i -= 1
                j += 1
            
            return True
        
        def backtracking(cur: int):
            """
            path存储棋盘的放置情况。cur代表当前遍历到第几行。
            """
            if cur == n:
            	# 直接path.copy()即可。
                ans.append(path.copy())
                return
            
            # 遍历当前行的每一列
            for col in range(n):
                if isValid(cur, col):
                    # 注意,因为path是字符串列表了,所以赋值如下
                    path[cur] = '.' * col + 'Q' + '.' * (n - col - 1)
                    backtracking(cur + 1)  # 递归到下一行
                    # 回溯,重置当前行
                    path[cur] = '.' * n
        
        backtracking(0)
        return ans

效率:35ms,击败29.82%

代码3:path为嵌套的字符列表,且使用三个set去判断位置合法(推荐)

主对角线的特点是行索引减去列索引的值相同(即 row - col 是常数)

譬如对于下标从0开始的4*4的棋盘,位置[0,0]和[1,1]是一条主对角线上的。同样的,[0,1]和[1,2]是一条主对角线上的。他们共同的特点就是行索引减去列索引的值相同。

副对角线的特点是行索引加上列索引的值相同(即 row + col 是常数)

譬如对于下标从0开始的4*4的棋盘,位置[1,1]和[0,2]是一条副对角线上的。同样的,[1,2]和[0,3]是一条副对角线上的。他们共同的特点就是行索引加上列索引的值相同。

所以代码如下:

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans = []
        # 初始化棋盘
        path = [['.' for _ in range(n)] for _ in range(n)]
        
        # 用集合记录列、主对角线、副对角线是否被占用
        used_cols = set()          # 列
        used_diag1 = set()         # 主对角线(row - col)
        used_diag2 = set()         # 副对角线(row + col)
        
        def backtracking(row):
            if row == n:
                # 将当前棋盘状态转换为字符串列表
                ans.append([''.join(row) for row in path])
                return
            
            for col in range(n):
                # 检查列、主对角线、副对角线是否被占用
                if col in used_cols or (row - col) in used_diag1 or (row + col) in used_diag2:
                    continue  # 如果位置不合法,跳过当前位置
                # 放置皇后
                path[row][col] = 'Q'
                used_cols.add(col)
                used_diag1.add(row - col)
                used_diag2.add(row + col)
                
                # 递归处理下一行
                backtracking(row + 1)
                
                # 回溯,撤销皇后
                path[row][col] = '.'
                used_cols.remove(col)
                used_diag1.remove(row - col)
                used_diag2.remove(row + col)
        
        backtrack(0)
        return ans

效率:7ms,击败89.23%

可以看到,效率大大提升。

代码4:path为字符串列表,且使用三个set去判断位置是否合法

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans = []
        # 初始化棋盘为字符串列表
        path = ['.' * n for _ in range(n)]
        
        # 用集合记录列、主对角线、副对角线是否被占用
        used_cols = set()          # 列
        used_diag1 = set()         # 主对角线(row - col)
        used_diag2 = set()         # 副对角线(row + col)
        
        def backtrack(row):
            if row == n:
                # 将当前棋盘状态保存到结果中
                ans.append(path[:])
                return
            
            for col in range(n):
                # 检查列、主对角线、副对角线是否被占用
                if col in used_cols or (row - col) in used_diag1 or (row + col) in used_diag2:
                    continue  # 如果冲突,跳过当前列
                
                # 放置皇后:创建一个新字符串
                new_row = path[row][:col] + 'Q' + path[row][col+1:]
                path[row] = new_row  # 更新当前行
                used_cols.add(col)
                used_diag1.add(row - col)
                used_diag2.add(row + col)
                
                # 递归处理下一行
                backtrack(row + 1)
                
                # 回溯,撤销皇后
                path[row] = '.' * n  # 恢复当前行
                used_cols.remove(col)
                used_diag1.remove(row - col)
                used_diag2.remove(row + col)
        
        backtrack(0)
        return ans

效率:17ms,击败44.10%

37.解数独(hard)

哎……太恐怖了,万一面试的时候要求手撕这道题,瓦塔西该怎么办哇呜呜呜┭┮﹏┭┮

首先,这道题的题目提示,这个棋盘是固定9*9大小的,而且有且仅有一个解

所以不同于以往的回溯题目,我们的backtracking函数是空返回的,这道题我们应该返回一个布尔值。一旦发现布尔值为真,就直接返回答案,不再需要继续搜索了。

而以往的回溯题目,譬如N皇后问题,是有多个结果的,我们必须要搜索完,才能拿到所有的结果,所以空返回。

其次,数独问题中,空格(需要填充的位置)分布在整个棋盘中,而不是逐行或逐列排列。而N皇后中,每一行一定存在一个皇后,无论这个皇后放置在第几列,。因此,我们需要同时遍历行和列,找到下一个需要填充的空格。

最后,因为这道题要填的是数字1-9,也就是说不仅要考虑一个位置的占用状态,还要考虑这个位置具体被什么数字占用了。像N皇后的话,就只需要考虑一个位置是否被占用了而已。所以这道题还需要一个递归,来递归数字

综合以上,这道题有三重递归

代码1

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        def isValid(row, col, num):
            # 检查行是否重复
            for i in range(9):
                if board[row][i] == num:
                    return False
            # 检查列是否重复
            for i in range(9):
                if board[i][col] == num:
                    return False
            # 检查 3x3 宫格是否重复
            start_row = (row // 3) * 3
            start_col = (col // 3) * 3
            for i in range(3):
                for j in range(3):
                    if board[start_row + i][start_col + j] == num:
                        return False
            return True

        def backtrack():
            # 遍历每一行
            for row in range(9):
                # 遍历每一列
                for col in range(9):
                    # 如果当前位置是空格
                    if board[row][col] == '.':
                        # 尝试填充数字 1-9
                        for num in map(str, range(1, 10)):
                            if isValid(row, col, num):
                                # 填充数字
                                board[row][col] = num
                                # 递归填充下一个空格
                                if backtrack():
                                    return True
                                # 回溯,撤销填充
                                board[row][col] = '.'
                        # 如果 1-9 都无法填充,返回 False
                        return False
            # 如果所有空格都填充完毕,返回 True
            return True

        # 开始回溯
        backtrack()

最后发现通过了6/7。超出时间限制。

我们需要进行优化。在这我们发现,每次回溯时我们都会重新遍历整个棋盘,来寻找空格的位置。所以我们可以提前预收集空格的位置,每次回溯时只遍历这些位置即可。

代码2:预收集空格的位置

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        # 预收集所有空格的位置
        empty_cells = []
        for row in range(9):
            for col in range(9):
                if board[row][col] == '.':
                    empty_cells.append((row, col))

        def isValid(row, col, num):
            # 检查行是否重复
            for i in range(9):
                if board[row][i] == num:
                    return False
            # 检查列是否重复
            for i in range(9):
                if board[i][col] == num:
                    return False
            # 检查 3x3 宫格是否重复
            start_row = (row // 3) * 3
            start_col = (col // 3) * 3
            for i in range(3):
                for j in range(3):
                    if board[start_row + i][start_col + j] == num:
                        return False
            return True

        def backtrack(index):
            """
            回溯函数,index 表示当前处理到第几个空格。
            """
            # 如果所有空格都填充完毕,返回 True
            if index == len(empty_cells):
                return True
            # 获取当前空格的行和列
            row, col = empty_cells[index]
            # 尝试填充数字 1-9
            for num in map(str, range(1, 10)):
                if isValid(row, col, num):
                    # 填充数字
                    board[row][col] = num
                    # 递归填充下一个空格
                    if backtrack(index + 1):
                        return True
                    # 回溯,撤销填充
                    board[row][col] = '.'
            # 如果 1-9 都无法填充,返回 False
            return False

        # 开始回溯
        backtrack(0)

发现还是只通过了6/7。超出时间限制。草啊。

我们继续优化,发现isValid函数的代码很冗长,我们是否可以优化?

答案是可以的,在 isValid 函数中,行、列和 3x3 宫格的检查是分开的,每次都需要遍历 9 个元素。我们可以尝试将行、列和宫格的检查合并到一个循环中,减少总的循环次数。

代码3:代码2优化版:合并行、列和宫格的检查到一个循环中

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        # 预收集所有空格的位置
        empty_cells = []
        for row in range(9):
            for col in range(9):
                if board[row][col] == '.':
                    empty_cells.append((row, col))

        def isValid(row, col, num):
            """
            检查在 (row, col) 位置填充 num 是否合法。
            """
            # 预计算 3x3 宫格的起始行和起始列
            start_row = (row // 3) * 3
            start_col = (col // 3) * 3

            # 合并检查行、列和 3x3 宫格
            for i in range(9):
                # 检查行是否重复
                if board[row][i] == num:
                    return False
                # 检查列是否重复
                if board[i][col] == num:
                    return False
                # 检查 3x3 宫格是否重复
                if i < 3 and (board[start_row + i][start_col] == num or
                              board[start_row + i][start_col + 1] == num or
                              board[start_row + i][start_col + 2] == num):
                    return False

            return True

        def backtrack(index):
            """
            回溯函数,index 表示当前处理到第几个空格。
            """
            # 如果所有空格都填充完毕,返回 True
            if index == len(empty_cells):
                return True
            # 获取当前空格的行和列
            row, col = empty_cells[index]
            # 尝试填充数字 1-9
            for num in map(str, range(1, 10)):
                if isValid(row, col, num):
                    # 填充数字
                    board[row][col] = num
                    # 递归填充下一个空格
                    if backtrack(index + 1):
                        return True
                    # 回溯,撤销填充
                    board[row][col] = '.'
            # 如果 1-9 都无法填充,返回 False
            return False
            
        backtrack(0)

他大爷的,还是通过6/7,大悲……不过看示例,还是提升了3~4ms的。

代码4:代码3优化版:提前记录每行、每列和每个宫格中已存在的数

我们既然已经提前遍历原始棋盘,从而预收集了空格的位置。那么我们也可以提前记录每行、每列和每个宫格中已经存在的数字,从而避免重复检查。

为此,我们定义了三个set列表,分别存储每行、每列、每个宫格已经存在的数字。

但需要注意的是,我们需要将数独棋盘上的每个单元格映射到其对应的 3x3 宫格的编号。也就是row // 3 * 3 + col // 3的作用是确定当前单元格 (row, col) 所在的 3x3 宫格的索引。当然,改成row // 3+col // 3 * 3也是可以的。

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        # 预收集所有空格的位置
        empty_cells = []
        rows = [set() for _ in range(9)]
        cols = [set() for _ in range(9)]
        boxes = [set() for _ in range(9)]

        for row in range(9):
            for col in range(9):
                if board[row][col] == '.':
                    empty_cells.append((row, col))
                else:
                    num = int(board[row][col])
                    rows[row].add(num)
                    cols[col].add(num)
                    boxes[row // 3 * 3 + col // 3].add(num)

        def backtrack(index):
            # 如果所有空格都填充完毕,返回 True
            if index == len(empty_cells):
                return True
            # 获取当前空格的行和列
            row, col = empty_cells[index]
            box_idx = row // 3 * 3 + col // 3
            # 尝试填充数字 1-9
            for num in range(1, 10):
                if num not in rows[row] and num not in cols[col] and num not in boxes[box_idx]:
                    # 更新状态
                    board[row][col] = str(num)
                    rows[row].add(num)
                    cols[col].add(num)
                    boxes[box_idx].add(num)
                    # 递归填充下一个空格
                    if backtrack(index + 1):
                        return True
                    # 回溯,撤销填充
                    board[row][col] = '.'
                    rows[row].remove(num)
                    cols[col].remove(num)
                    boxes[box_idx].remove(num)
            # 如果 1-9 都无法填充,返回 False
            return False

        backtrack(0)

效率:1639ms,击败51.86%

终于……通过了……呃呃呃……