티스토리 뷰

오늘 리뷰할 문제는 binary tree maximm path sum입니다. https://leetcode.com/problems/binary-tree-maximum-path-sum/

 

Binary Tree Maximum Path Sum - LeetCode

Can you solve this real interview question? Binary Tree Maximum Path Sum - A path in a binary tree is a sequence of nodes where each pair of adjacent nodes in the sequence has an edge connecting them. A node can only appear in the sequence at most once. No

leetcode.com

 

이 문제는 이진 트리가 주어졌을 때, 중복된 노드를 순회하지 않는 경로 중, 경로의 node value의 합이 최대인 값을 구하는 문제입니다.

 

이 문제를 풀었을 때도, dp로 풀어야 겠다고 생각하고, 접근을 했다.

 

아래와 같이 dp 배열을 정의하고 나면

max[node] = 해당 node를 root로 하는 이진 트리에서 maximum path sum의 값 

append_max[node] = 해당 node를 root로 하면서 위로 연결될 수 있는 path중 maximum path의 값

 

관계식을 아래처럼 쓸 수있다.

append_max[node] = max(apend_max[node.left], append_max[node.right], 0) + node.val

max[node] = max(max[node.left], max[node.right], append_max[node], append_max[node.left] + append_max[node.right] + node.val)

 

이 때쯤 깨달은 사실, 한 node 순회를 한번만 하기 때문에 dp로 풀 필요는 없다.

 

위 관계식을 모든 노드에 대해서 dfs로 순회를 하면 아래와 같이 코드가 나온다.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


from typing import Optional


class Solution:
    def traverseNode(self, node: TreeNode) -> (int, int):
        # 자식이 없을 때는 바로 내려준다.
        if node.left is None and node.right is None:
            return node.val, node.val
        elif node.left is None or node.right is None:
            non_empty_node = node.left or node.right
            child_max, child_append_max = self.traverseNode(non_empty_node)
            append_max_val = max(child_append_max, 0) + node.val
            max_val = max(child_max, append_max_val)
            return max_val, append_max_val
        else:
            left_max, left_append_max = self.traverseNode(node.left)
            right_max, right_append_max = self.traverseNode(node.right)
            append_max = max(left_append_max, right_append_max, 0) + node.val
            max_val = max(
                left_max,
                right_max,
                append_max,
                left_append_max + right_append_max + node.val,
            )
            return max_val, append_max

    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        max_val, _ = self.traverseNode(root)
        return max_val


assert Solution().maxPathSum(TreeNode(1, TreeNode(2), TreeNode(3))) == 6
assert (
    Solution().maxPathSum(
        TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
    )
    == 42
)
assert Solution().maxPathSum(TreeNode(-3, TreeNode(-1))) == -1
assert Solution().maxPathSum(TreeNode(-3)) == -3

 

이렇게 제출을 해보면 runtime은 상위 5%, 메모리 사용량은 상위 95%가 나오는데.

 

분포까지 보면, 아래처럼 보인다.

로직은 크게 문제가 없지만, memory 사용량 측면에서 아쉬워서 다른 사람의 코드를 봤을 때.

 

1. 순회할 때 저장해야하는 데이터를 두가지 종류로 했었는데, append_max만 관리를 하고. 전체 max는 매 순회때마다 저장을 하면 된다는 사실을 알 수 있었다.

2. edge case(children이 없는 노드)가 있는 경우 어떻게 처리를 해야할 지 고민이 었는데, node가 없을 경우 0을 return하면 된다.

 

변경된 solution은 다음과 같다.

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        max_sum = [-1001]

        def traverseNode(node: Optional[TreeNode]) -> int:
            if node is None:
                return 0

            left_max = max(traverseNode(node.left), 0)
            right_max = max(traverseNode(node.right), 0)

            max_sum[0] = max(max_sum[0], left_max + right_max + node.val)

            return max(left_max, right_max) + node.val

        traverseNode(root)
        return max_sum[0]

 

이렇게 하니 runtime이 생각보다 많이 느려졌는데, array를 사용해서 데이터를 주고 받는게 아무래도 시간을 더 쓰는 부분이 된 것 같다.

memory 사용량의 원인이 변수 하나 더 있는게 아니라, dfs로 순회하는 부분인 것 같다고 생각이 번뜩 스쳤다.

 

일단 dfs + function stack memory 사용 + leaf node 순회 간략화된 코드 버젼으로 마무리했다.

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        def traverseNode(node: Optional[TreeNode]) -> (int, int):
            if node is None:
                return -1001, 0

            lm, lam = traverseNode(node.left)
            rm, ram = traverseNode(node.right)

            cam = max(lam, ram, 0) + node.val
            cm = max(lm, rm, lam + ram + node.val, cam)

            return cm, cam

        m, _ = traverseNode(root)
        return m

 

최종 결과는

 

코드를 짜는 것도 재밌지만, 발전시키는것도 재미가 들어가는 것 같다.

반응형