#!/usr/bin/env python3
from typing import List

# Header line: the node count followed by a second integer (presumably the
# number of entrances; it is parsed but not used afterwards).
# split() with no argument tolerates repeated spaces, tabs and trailing
# whitespace, none of which split(" ") handles without producing "" tokens.
line = input().split()
n = int(line[0])
e = int(line[1])
# Second line: the entrance node ids. An empty line yields an empty list.
entrances = [int(token) for token in input().split()]

# Nodes are numbered 1..n; index 0 is a placeholder.
# forest[i] lists the children of node i, parent[c] is the node that listed c
# as a child (0 when no node did).
forest = [[] for _ in range(n + 1)]
parent = [0] * (n + 1)
for i in range(1, n + 1):
    line = input().split()
    # The first token is presumably the child count; "0" (or a blank line)
    # means the node has no children. The remaining tokens are child ids.
    if not line or line[0] == "0":
        continue
    for child in line[1:]:
        c = int(child)
        forest[i].append(c)
        parent[c] = i


def iterative_count(entrance: int, forest: List[List[int]], parent: List[int]) -> tuple:
    """Count the leaves below ``entrance`` and sum their depths, without recursion.

    Args:
        entrance: Index of the node whose subtree is measured.
        forest: Child lists indexed by node; ``forest[v]`` holds the children
            of node ``v``.
        parent: ``parent[v]`` is the node whose child list contains ``v``.
            Only entries for nodes strictly below ``entrance`` are read.

    Returns:
        A pair ``(leaves, total_depth)`` where ``leaves`` is the number of
        childless nodes in the subtree rooted at ``entrance`` and
        ``total_depth`` is the sum, over those nodes, of the number of edges
        between ``entrance`` and the node. An entrance with no children is
        itself a single leaf at depth 0, i.e. ``(1, 0)``.
    """
    if not forest[entrance]:
        return 1, 0

    # Scratch state is keyed by visited node only, so the function needs no
    # module-level size and its cost is proportional to the subtree visited
    # rather than to the whole forest.
    leaves = {}
    depth_sum = {}

    # Pre-order walk: every node is recorded before any of its descendants.
    order = []
    stack = [entrance]
    while stack:
        node = stack.pop()
        order.append(node)
        leaves[node] = 0 if forest[node] else 1
        depth_sum[node] = 0
        stack.extend(forest[node])

    # Consuming the pre-order from the end handles each node after all of its
    # descendants, so its totals are final when folded into its parent.
    # order[0] is the entrance itself and is never folded further up.
    while len(order) > 1:
        node = order.pop()
        above = parent[node]
        leaves[above] += leaves[node]
        # Every leaf under `node` is one edge further away from `above`.
        depth_sum[above] += leaves[node] + depth_sum[node]

    return leaves[entrance], depth_sum[entrance]


# Add up the second component returned by iterative_count for every entrance,
# then print twice that total.
ans = sum(
    iterative_count(entrance, forest, parent)[1] for entrance in entrances
)

print(2 * ans)
