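Secure Aggregation Protocols
This document covers the SecAgg, SecAgg+, and LightSecAgg protocols. The LightSecAgg protocol has not been implemented yet, so its diagram and abstraction may not be accurate in practice. The SecAgg protocol can be regarded as a special case of the SecAgg+ protocol.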
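The SecAgg+ abstraction
In this implementation, each client is assigned a unique index (int) for secure aggregation, so many of the Python dictionaries used have keys of type int rather than ClientProxy.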
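As a rough illustration of the indexing idea, the sketch below maps client identifiers to consecutive integers. The helper name `assign_indices` and the use of plain string IDs in place of `ClientProxy` objects are assumptions made for this example; they are not part of the Flower API.

```python
from typing import Dict, List


def assign_indices(client_ids: List[str]) -> Dict[int, str]:
    """Map each client to a unique integer index (hypothetical helper)."""
    # Sorting makes the assignment deterministic for a given set of clients.
    return {idx: cid for idx, cid in enumerate(sorted(client_ids))}


index_to_client = assign_indices(["client-b", "client-a", "client-c"])
print(index_to_client)
```

Dictionaries exchanged during the protocol can then be keyed by these integers, and the server can translate back to the actual client object when it needs to send a message.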
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class SecAggPlusProtocol(ABC):
    """Abstract base class for the SecAgg+ protocol implementations."""

    @abstractmethod
    def generate_graph(
        self, clients: List[ClientProxy], k: int
    ) -> ClientGraph:
        """Build a k-degree undirected graph of clients.

        Each client will only generate pair-wise masks with its k neighbours.
        SecAgg corresponds to a complete graph, where every client is a
        neighbour of all other clients.
        This function may need extra inputs to decide on the generation of the graph.
        """

    @abstractmethod
    def setup_config(
        self, clients: List[ClientProxy], config_dict: Dict[str, Scalar]
    ) -> SetupConfigResultsAndFailures:
        """Configure the next round of secure aggregation.

        (SetupConfigRes is an empty class.)
        """

    @abstractmethod
    def ask_keys(
        self,
        clients: List[ClientProxy],
        ask_keys_ins_list: List[AskKeysIns],
    ) -> AskKeysResultsAndFailures:
        """Ask public keys.

        (AskKeysIns is an empty class, and hence ask_keys_ins_list can be omitted.)
        """

    @abstractmethod
    def share_keys(
        self,
        clients: List[ClientProxy],
        public_keys_dict: Dict[int, AskKeysRes],
        graph: ClientGraph,
    ) -> ShareKeysResultsAndFailures:
        """Send public keys."""

    @abstractmethod
    def ask_vectors(
        self,
        clients: List[ClientProxy],
        forward_packet_list_dict: Dict[int, List[ShareKeysPacket]],
        client_instructions: Optional[Dict[int, FitIns]] = None,
    ) -> AskVectorsResultsAndFailures:
        """Ask vectors of local model parameters.

        (If client_instructions is not None, local models will be trained in
        the ask vectors stage, rather than trained in parallel as the protocol
        goes through the previous stages.)
        """

    @abstractmethod
    def unmask_vectors(
        self,
        clients: List[ClientProxy],
        dropout_clients: List[ClientProxy],
        graph: ClientGraph,
    ) -> UnmaskVectorsResultsAndFailures:
        """Unmask and compute the aggregated model.

        UnmaskVectorsRes contains shares of keys needed to generate masks.
        """
The Flower server will execute the stages and process the received results in the following order:
sequenceDiagram
participant S as Flower Server
participant P as SecAgg+ Protocol
participant C1 as Flower Client
participant C2 as Flower Client
participant C3 as Flower Client
S->>P: generate_graph
activate P
P-->>S: client_graph
deactivate P
Note left of P: Stage 0:<br/>Setup Config
rect rgb(249, 219, 130)
S->>P: setup_config<br/>clients, config_dict
activate P
P->>C1: SetupConfigIns
deactivate P
P->>C2: SetupConfigIns
P->>C3: SetupConfigIns
C1->>P: SetupConfigRes (empty)
C2->>P: SetupConfigRes (empty)
C3->>P: SetupConfigRes (empty)
activate P
P-->>S: None
deactivate P
end
Note left of P: Stage 1:<br/>Ask Keys
rect rgb(249, 219, 130)
S->>P: ask_keys<br/>clients
activate P
P->>C1: AskKeysIns (empty)
deactivate P
P->>C2: AskKeysIns (empty)
P->>C3: AskKeysIns (empty)
C1->>P: AskKeysRes
C2->>P: AskKeysRes
C3->>P: AskKeysRes
activate P
P-->>S: public keys
deactivate P
end
Note left of P: Stage 2:<br/>Share Keys
rect rgb(249, 219, 130)
S->>P: share_keys<br/>clients, public_keys_dict,<br/>client_graph
activate P
P->>C1: ShareKeysIns
deactivate P
P->>C2: ShareKeysIns
P->>C3: ShareKeysIns
C1->>P: ShareKeysRes
C2->>P: ShareKeysRes
C3->>P: ShareKeysRes
activate P
P-->>S: encrypted key shares
deactivate P
end
Note left of P: Stage 3:<br/>Ask Vectors
rect rgb(249, 219, 130)
S->>P: ask_vectors<br/>clients,<br/>forward_packet_list_dict
activate P
P->>C1: AskVectorsIns
deactivate P
P->>C2: AskVectorsIns
P->>C3: AskVectorsIns
C1->>P: AskVectorsRes
C2->>P: AskVectorsRes
activate P
P-->>S: masked vectors
deactivate P
end
Note left of P: Stage 4:<br/>Unmask Vectors
rect rgb(249, 219, 130)
S->>P: unmask_vectors<br/>clients, dropout_clients,<br/>client_graph
activate P
P->>C1: UnmaskVectorsIns
deactivate P
P->>C2: UnmaskVectorsIns
C1->>P: UnmaskVectorsRes
C2->>P: UnmaskVectorsRes
activate P
P-->>S: key shares
deactivate P
end
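The role of the client graph and the pair-wise masks in the stages above can be sketched as follows. This is a simplified toy, not the protocol itself: the ring-shaped graph, the modulus `MOD`, the string-seeded `random.Random` standing in for a key agreement, and all helper names are assumptions for illustration. It also omits the self-masks and the key-share recovery that the real protocol needs to handle dropouts.

```python
import random
from typing import Dict, List

MOD = 2**16  # hypothetical modulus for the masked arithmetic
ClientGraph = Dict[int, List[int]]


def generate_ring_graph(num_clients: int, k: int) -> ClientGraph:
    """Connect each client index to its k nearest neighbours on a ring."""
    if k % 2 != 0 or not 0 < k < num_clients:
        raise ValueError("k must be even and satisfy 0 < k < num_clients")
    half = k // 2
    offsets = list(range(-half, 0)) + list(range(1, half + 1))
    # Neighbours are the `half` clients before and after i, wrapping around.
    return {
        i: [(i + off) % num_clients for off in offsets]
        for i in range(num_clients)
    }


def pairwise_mask(i: int, j: int, length: int) -> List[int]:
    """Derive the mask shared by clients i and j from a toy seed."""
    # A real protocol derives this seed from exchanged keys; here it is faked.
    rng = random.Random(f"{min(i, j)}-{max(i, j)}")
    return [rng.randrange(MOD) for _ in range(length)]


def mask_vector(i: int, vector: List[int], neighbours: List[int]) -> List[int]:
    """Add masks for larger-indexed neighbours, subtract for smaller ones."""
    masked = list(vector)
    for j in neighbours:
        sign = 1 if i < j else -1
        mask = pairwise_mask(i, j, len(vector))
        masked = [(m + sign * p) % MOD for m, p in zip(masked, mask)]
    return masked


graph = generate_ring_graph(num_clients=5, k=2)
vectors = {i: [i, 10 * i] for i in graph}
masked = {i: mask_vector(i, vectors[i], graph[i]) for i in graph}
# Summing the masked vectors cancels every pair-wise mask.
total = [sum(col) % MOD for col in zip(*masked.values())]
print(total)
```

Because each edge of the graph contributes one mask with a positive sign and the same mask with a negative sign, the sum of the masked vectors equals the sum of the original vectors, here `[10, 100]`, while no individual masked vector is sent in the clear.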
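The LightSecAgg abstraction
In this implementation, each client is assigned a unique index (int) for secure aggregation, so many of the Python dictionaries used have keys of type int rather than ClientProxy.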
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class LightSecAggProtocol(ABC):
    """Abstract base class for the LightSecAgg protocol implementations."""

    @abstractmethod
    def setup_config(
        self, clients: List[ClientProxy], config_dict: Dict[str, Scalar]
    ) -> LightSecAggSetupConfigResultsAndFailures:
        """Configure the next round of secure aggregation."""

    @abstractmethod
    def ask_encrypted_encoded_masks(
        self,
        clients: List[ClientProxy],
        public_keys_dict: Dict[int, LightSecAggSetupConfigRes],
    ) -> AskEncryptedEncodedMasksResultsAndFailures:
        """Ask encrypted encoded masks.

        The protocol uses Diffie-Hellman keys to build pair-wise secure
        channels for transferring the encoded masks.
        """

    @abstractmethod
    def ask_masked_models(
        self,
        clients: List[ClientProxy],
        forward_packet_list_dict: Dict[int, List[EncryptedEncodedMasksPacket]],
        client_instructions: Optional[Dict[int, FitIns]] = None,
    ) -> AskMaskedModelsResultsAndFailures:
        """Ask the masked local models.

        (If client_instructions is not None, local models will be trained in
        the ask masked models stage, rather than trained in parallel as the
        protocol goes through the previous stages.)
        """

    @abstractmethod
    def ask_aggregated_encoded_masks(
        self,
        clients: List[ClientProxy],
    ) -> AskAggregatedEncodedMasksResultsAndFailures:
        """Ask aggregated encoded masks."""
The Flower server will execute the stages and process the received results in the following order:
sequenceDiagram
participant S as Flower Server
participant P as LightSecAgg Protocol
participant C1 as Flower Client
participant C2 as Flower Client
participant C3 as Flower Client
Note left of P: Stage 0:<br/>Setup Config
rect rgb(249, 219, 130)
S->>P: setup_config<br/>clients, config_dict
activate P
P->>C1: LightSecAggSetupConfigIns
deactivate P
P->>C2: LightSecAggSetupConfigIns
P->>C3: LightSecAggSetupConfigIns
C1->>P: LightSecAggSetupConfigRes
C2->>P: LightSecAggSetupConfigRes
C3->>P: LightSecAggSetupConfigRes
activate P
P-->>S: public keys
deactivate P
end
Note left of P: Stage 1:<br/>Ask Encrypted Encoded Masks
rect rgb(249, 219, 130)
S->>P: ask_encrypted_encoded_masks<br/>clients, public_keys_dict
activate P
P->>C1: AskEncryptedEncodedMasksIns
deactivate P
P->>C2: AskEncryptedEncodedMasksIns
P->>C3: AskEncryptedEncodedMasksIns
C1->>P: AskEncryptedEncodedMasksRes
C2->>P: AskEncryptedEncodedMasksRes
C3->>P: AskEncryptedEncodedMasksRes
activate P
P-->>S: forward packets
deactivate P
end
Note left of P: Stage 2:<br/>Ask Masked Models
rect rgb(249, 219, 130)
S->>P: ask_masked_models<br/>clients, forward_packet_list_dict
activate P
P->>C1: AskMaskedModelsIns
deactivate P
P->>C2: AskMaskedModelsIns
P->>C3: AskMaskedModelsIns
C1->>P: AskMaskedModelsRes
C2->>P: AskMaskedModelsRes
activate P
P-->>S: masked local models
deactivate P
end
Note left of P: Stage 3:<br/>Ask Aggregated Encoded Masks
rect rgb(249, 219, 130)
S->>P: ask_aggregated_encoded_masks<br/>clients
activate P
P->>C1: AskAggregatedEncodedMasksIns
deactivate P
P->>C2: AskAggregatedEncodedMasksIns
C1->>P: AskAggregatedEncodedMasksRes
C2->>P: AskAggregatedEncodedMasksRes
activate P
P-->>S: the aggregated model
deactivate P
end
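In the diagram above, the third client does not reply in Stage 2 and is therefore left out of Stage 3. A minimal sketch of how a server might derive the `surviving_clients` list from the replies it received, assuming responses are collected in a dictionary keyed by client index; `split_by_response` is a hypothetical helper, not a Flower function.

```python
from typing import Dict, List, Tuple


def split_by_response(
    all_clients: List[int], responses: Dict[int, bytes]
) -> Tuple[List[int], List[int]]:
    """Return (surviving, dropped) client indices for one stage."""
    surviving = [idx for idx in all_clients if idx in responses]
    dropped = [idx for idx in all_clients if idx not in responses]
    return surviving, dropped


# Clients 0 and 1 returned a masked model; client 2 did not reply.
surviving, dropped = split_by_response(
    [0, 1, 2], {0: b"model-0", 1: b"model-1"}
)
print(surviving, dropped)
```

The same split could feed the `available_clients` and `dropout_clients` fields of `UnmaskVectorsIns` in the SecAgg+ flow.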
Types
# the SecAgg+ protocol
# (The aliases below refer to the dataclasses defined further down.)
from dataclasses import dataclass
from typing import Dict, List, Tuple

ClientGraph = Dict[int, List[int]]

SetupConfigResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, SetupConfigRes]], List[BaseException]
]

AskKeysResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskKeysRes]], List[BaseException]
]

ShareKeysResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, ShareKeysRes]], List[BaseException]
]

AskVectorsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskVectorsRes]], List[BaseException]
]

UnmaskVectorsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, UnmaskVectorsRes]], List[BaseException]
]

FitResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, FitRes]], List[BaseException]
]


@dataclass
class SetupConfigIns:
    sec_agg_cfg_dict: Dict[str, Scalar]


@dataclass
class SetupConfigRes:
    pass


@dataclass
class AskKeysIns:
    pass


@dataclass
class AskKeysRes:
    """Ask Keys Stage Response from client to server"""

    pk1: bytes
    pk2: bytes


@dataclass
class ShareKeysIns:
    public_keys_dict: Dict[int, AskKeysRes]


@dataclass
class ShareKeysPacket:
    source: int
    destination: int
    ciphertext: bytes


@dataclass
class ShareKeysRes:
    share_keys_res_list: List[ShareKeysPacket]


@dataclass
class AskVectorsIns:
    ask_vectors_in_list: List[ShareKeysPacket]
    fit_ins: FitIns


@dataclass
class AskVectorsRes:
    parameters: Parameters


@dataclass
class UnmaskVectorsIns:
    available_clients: List[int]
    dropout_clients: List[int]


@dataclass
class UnmaskVectorsRes:
    share_dict: Dict[int, bytes]
# the LightSecAgg protocol
LightSecAggSetupConfigResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, LightSecAggSetupConfigRes]], List[BaseException]
]

AskEncryptedEncodedMasksResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskEncryptedEncodedMasksRes]], List[BaseException]
]

AskMaskedModelsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskMaskedModelsRes]], List[BaseException]
]

AskAggregatedEncodedMasksResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskAggregatedEncodedMasksRes]], List[BaseException]
]


@dataclass
class LightSecAggSetupConfigIns:
    sec_agg_cfg_dict: Dict[str, Scalar]


@dataclass
class LightSecAggSetupConfigRes:
    pk: bytes


@dataclass
class AskEncryptedEncodedMasksIns:
    public_keys_dict: Dict[int, LightSecAggSetupConfigRes]


@dataclass
class EncryptedEncodedMasksPacket:
    source: int
    destination: int
    ciphertext: bytes


@dataclass
class AskEncryptedEncodedMasksRes:
    packet_list: List[EncryptedEncodedMasksPacket]


@dataclass
class AskMaskedModelsIns:
    packet_list: List[EncryptedEncodedMasksPacket]
    fit_ins: FitIns


@dataclass
class AskMaskedModelsRes:
    parameters: Parameters


@dataclass
class AskAggregatedEncodedMasksIns:
    surviving_clients: List[int]


@dataclass
class AskAggregatedEncodedMasksRes:
    aggregated_encoded_mask: Parameters
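The packet types above carry a `source` and a `destination` index, which suggests that the server regroups the packets it receives by destination before forwarding them in the next stage. The sketch below shows one way this regrouping might look for `ShareKeysPacket`; the helper `route_packets` and the check that a packet's `source` matches its sender are assumptions for illustration, not confirmed behaviour of any implementation.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ShareKeysPacket:
    source: int
    destination: int
    ciphertext: bytes


def route_packets(
    results: Dict[int, List[ShareKeysPacket]]
) -> Dict[int, List[ShareKeysPacket]]:
    """Group packets by destination index (hypothetical server-side helper)."""
    forward: Dict[int, List[ShareKeysPacket]] = defaultdict(list)
    for sender, packets in results.items():
        for packet in packets:
            # Skip packets whose claimed source does not match the sender.
            if packet.source == sender:
                forward[packet.destination].append(packet)
    return dict(forward)


# Packets returned by clients 0 and 1, keyed by the sending client's index.
results = {
    0: [ShareKeysPacket(0, 1, b"a")],
    1: [ShareKeysPacket(1, 0, b"b")],
}
forward_packet_list_dict = route_packets(results)
print(sorted(forward_packet_list_dict))
```

The resulting dictionary has the same shape as the `forward_packet_list_dict` argument of `ask_vectors`; the same grouping would apply to `EncryptedEncodedMasksPacket` in LightSecAgg.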