Skip to content

Commit 5338b87

Browse files
committed
solved(python): baekjoon 1761
1 parent fb12a68 commit 5338b87

4 files changed

Lines changed: 128 additions & 0 deletions

File tree

baekjoon/python/1761/__init__.py

Whitespace-only changes.

baekjoon/python/1761/main.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import sys
2+
from collections import defaultdict, deque
3+
from typing import cast
4+
5+
read = lambda: sys.stdin.readline().rstrip()
6+
7+
8+
class Problem:
9+
def __init__(self):
10+
self.n = int(read())
11+
12+
self.nodes = defaultdict(set[tuple[int, int]])
13+
for _ in range(self.n - 1):
14+
src, dest, distance = map(int, read().split())
15+
self.nodes[src].add((dest, distance))
16+
self.nodes[dest].add((src, distance))
17+
18+
self.queries = [cast(tuple[int, int], tuple(map(int, read().split()))) for _ in range(int(read()))]
19+
20+
self.depth, self.parent, self.weights = self.make_tree()
21+
22+
def solve(self) -> None:
23+
for x, y in self.queries:
24+
print(self.weights[x] + self.weights[y] - 2 * self.weights[self.lca(x, y)])
25+
26+
def make_tree(self) -> tuple[list[int], list[list[int]], list[int]]:
27+
queue, visited = deque([(1, 0)]), {1}
28+
depth, parent, weights = (
29+
[0 for _ in range(self.n + 1)],
30+
[[0 for _ in range(20)] for _ in range(self.n + 1)],
31+
[0 for _ in range(self.n + 1)],
32+
)
33+
34+
while queue:
35+
node, idx = queue.popleft()
36+
depth[node] = idx
37+
38+
for child, distance in self.nodes[node]:
39+
if child not in visited:
40+
queue.append((child, idx + 1))
41+
visited.add(child)
42+
parent[child][0] = node
43+
weights[child] = weights[node] + distance
44+
45+
for k in range(1, 20):
46+
for node in range(1, self.n + 1):
47+
parent[node][k] = parent[parent[node][k - 1]][k - 1]
48+
49+
return depth, parent, weights
50+
51+
def lca(self, x: int, y: int) -> int:
52+
if self.depth[x] < self.depth[y]:
53+
x, y = y, x
54+
55+
for k in range(20)[::-1]:
56+
if self.depth[x] - (1 << k) >= self.depth[y]:
57+
x = self.parent[x][k]
58+
59+
if x == y:
60+
return x
61+
62+
for k in range(20)[::-1]:
63+
if self.parent[x][k] != self.parent[y][k]:
64+
x, y = self.parent[x][k], self.parent[y][k]
65+
66+
return self.parent[x][0]
67+
68+
69+
if __name__ == "__main__":
70+
Problem().solve()

baekjoon/python/1761/sample.json

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[
2+
{
3+
"input": [
4+
"7",
5+
"1 6 13",
6+
"6 3 9",
7+
"3 5 7",
8+
"4 1 3",
9+
"2 4 20",
10+
"4 7 2",
11+
"3",
12+
"1 6",
13+
"1 4",
14+
"2 6"
15+
],
16+
"expected": [
17+
"13",
18+
"3",
19+
"36"
20+
]
21+
}
22+
]

baekjoon/python/1761/test_main.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import json
2+
import os.path
3+
import unittest
4+
from io import StringIO
5+
from unittest.mock import patch
6+
7+
from parameterized import parameterized
8+
9+
from main import Problem
10+
11+
12+
def load_sample(filename: str):
13+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)
14+
15+
with open(path, "r") as file:
16+
return [(case["input"], case["expected"]) for case in json.load(file)]
17+
18+
19+
class TestCase(unittest.TestCase):
20+
@parameterized.expand(load_sample("sample.json"))
21+
def test_case(self, case: str, expected: list[str]):
22+
# When
23+
with (
24+
patch("sys.stdin.readline", side_effect=case),
25+
patch("sys.stdout", new_callable=StringIO) as output,
26+
):
27+
Problem().solve()
28+
29+
result = output.getvalue().rstrip()
30+
31+
# Then
32+
self.assertEqual("\n".join(expected), result)
33+
34+
35+
if __name__ == "__main__":
36+
unittest.main()

0 commit comments

Comments
 (0)