Python类型:连接序列
在 python 中,两个序列的连接通常由+操作符完成。但是,mypy 抱怨以下内容:
from typing import Sequence
def concat1(a: Sequence, b: Sequence) -> Sequence:
return a + b
没错:Sequence没有__add__. 但是,该函数对于“通常的”序列类型list, str,工作得非常好tuple。显然,还有其他序列类型不起作用(例如numpy.ndarray)。解决方案可能是以下内容:
from itertools import chain
def concat2(a: Sequence, b: Sequence) -> Sequence:
return list(chain(a, b))
现在,mypy 没有抱怨。但是连接字符串或元组总是给出一个列表。似乎有一个简单的解决方法:
def concat3(a: Sequence, b: Sequence) -> Sequence:
T = type(a)
return T(chain(a, b))
但是现在 mypy 不高兴了,因为 T get 的构造函数参数太多。更糟糕的是,该函数不再返回序列,而是返回一个生成器。
这样做的正确方法是什么?我觉得问题的一部分是 a 和 b 应该具有相同的类型,并且输出也将是相同的类型,但是类型注释并没有传达它。
注意:我知道使用''.join(a, b). 但是,我选择这个例子更多是为了说明目的。
回答
没有通用的方法来解决这个问题:Sequence包括不能以通用方式连接的类型。例如,无法连接任意range对象以创建新对象range并保留所有元素。
必须决定一种具体的连接方式,并将可接受的类型限制为提供所需操作的类型。
最简单的方法是让函数只请求所需的操作。如果预构建的协议typing不够用,可以回退到typing.Protocol为请求的操作定义自定义。
由于concat1/concat_add需要+实现,所以需要Protocolwith __add__。此外,由于加法通常适用于相似类型,因此__add__必须在具体类型上进行参数化——否则,协议要求所有可添加类型都可以添加到所有其他可添加类型。
# TypeVar to parameterize for specific types
SA = TypeVar('SA', bound='SupportsAdd')
class SupportsAdd(Protocol):
"""Any type T where +(:T, :T) -> T"""
def __add__(self: SA, other: SA) -> SA: ...
def concat_add(a: SA, b: SA) -> SA:
return a + b
这足以类型安全地连接基本序列,并拒绝混合类型连接。
reveal_type(concat_add([1, 2, 3], [12, 17])) # note: Revealed type is 'builtins.list*[builtins.int]'
reveal_type(concat_add("abc", "xyz")) # note: Revealed type is 'builtins.str*'
reveal_type(concat_add([1, 2, 3], "xyz")) # error: ...
请注意,这允许连接实现 的任何类型,__add__例如int。如果需要进一步的限制,请更仔细地定义协议 - 例如通过要求__len__和__getitem__。
通过链接键入连接有点复杂,但遵循相同的方法:AProtocol定义函数所需的功能,但为了类型安全,元素也应该被键入。
# TypeVar to parameterize for specific types and element types
C = TypeVar('C', bound='Chainable')
T = TypeVar('T', covariant=True)
# Parameterized by the element type T
class Chainable(Protocol[T]):
"""Any type C[T] where C[T](:Iterable[T]) -> C[T] and iter(:C[T]) -> Iterable[T]"""
def __init__(self, items: Iterable[T]): ...
def __iter__(self) -> Iterator[T]: ...
def concat_chain(a: C, b: C) -> C:
T = type(a)
return T(chain(a, b))
这足以类型安全地连接从它们自身构造的序列,并拒绝混合类型的连接和非序列。
reveal_type(concat_chain([1, 2, 3], [12, 17])) # note: Revealed type is 'builtins.list*[builtins.int]'
reveal_type(concat_chain("abc", "xyz")) # note: Revealed type is 'builtins.str*'
reveal_type(concat_chain([1, 2, 3], "xyz")) # error: ...
reveal_type(concat_chain(1, 2)) # error: ...