mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 10:21:52 +00:00
250 lines
7.0 KiB
Rust
250 lines
7.0 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::net::{Ipv4Addr, SocketAddr};
|
|
|
|
use imbl_value::InternedString;
|
|
use ipnet::Ipv4Net;
|
|
use itertools::Itertools;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::process::Command;
|
|
use x25519_dalek::{PublicKey, StaticSecret};
|
|
|
|
use crate::prelude::*;
|
|
use crate::util::Invoke;
|
|
use crate::util::io::write_file_atomic;
|
|
use crate::util::serde::Base64;
|
|
|
|
/// Name of the WireGuard interface managed by this module.
///
/// Passed to `wg-quick up`/`wg-quick down`, and also used as the basename of the
/// interface's config file (`/etc/wireguard/<name>.conf`).
pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel";
|
|
|
|
#[derive(Deserialize, Serialize, HasModel)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[model = "Model<Self>"]
|
|
pub struct WgServer {
|
|
pub port: u16,
|
|
pub key: Base64<WgKey>,
|
|
pub subnets: WgSubnetMap,
|
|
}
|
|
impl Default for WgServer {
|
|
fn default() -> Self {
|
|
Self {
|
|
port: 51820,
|
|
key: Base64(WgKey::generate()),
|
|
subnets: WgSubnetMap::default(),
|
|
}
|
|
}
|
|
}
|
|
impl WgServer {
|
|
pub fn server_config<'a>(&'a self) -> ServerConfig<'a> {
|
|
ServerConfig(self)
|
|
}
|
|
pub async fn sync(&self) -> Result<(), Error> {
|
|
Command::new("wg-quick")
|
|
.arg("down")
|
|
.arg(WIREGUARD_INTERFACE_NAME)
|
|
.invoke(ErrorKind::Network)
|
|
.await
|
|
.or_else(|e| {
|
|
let msg = e.source.to_string();
|
|
if msg.contains("does not exist") || msg.contains("is not a WireGuard interface") {
|
|
Ok(Vec::new())
|
|
} else {
|
|
Err(e)
|
|
}
|
|
})?;
|
|
write_file_atomic(
|
|
const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"),
|
|
self.server_config().to_string().as_bytes(),
|
|
)
|
|
.await?;
|
|
Command::new("wg-quick")
|
|
.arg("up")
|
|
.arg(WIREGUARD_INTERFACE_NAME)
|
|
.invoke(ErrorKind::Network)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[derive(Default, Deserialize, Serialize)]
|
|
pub struct WgSubnetMap(pub BTreeMap<Ipv4Net, WgSubnetConfig>);
|
|
impl Map for WgSubnetMap {
|
|
type Key = Ipv4Net;
|
|
type Value = WgSubnetConfig;
|
|
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
|
|
Self::key_string(key)
|
|
}
|
|
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
|
|
Ok(InternedString::from_display(key))
|
|
}
|
|
}
|
|
|
|
#[derive(Default, Deserialize, Serialize, HasModel)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[model = "Model<Self>"]
|
|
pub struct WgSubnetConfig {
|
|
pub name: InternedString,
|
|
pub clients: WgSubnetClients,
|
|
}
|
|
impl WgSubnetConfig {
|
|
pub fn new(name: InternedString) -> Self {
|
|
Self {
|
|
name,
|
|
..Self::default()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Default, Deserialize, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct WgSubnetClients(pub BTreeMap<Ipv4Addr, WgConfig>);
|
|
impl Map for WgSubnetClients {
|
|
type Key = Ipv4Addr;
|
|
type Value = WgConfig;
|
|
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
|
|
Self::key_string(key)
|
|
}
|
|
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
|
|
Ok(InternedString::from_display(key))
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct WgKey(StaticSecret);
|
|
impl WgKey {
|
|
pub fn generate() -> Self {
|
|
Self(StaticSecret::random_from_rng(
|
|
ssh_key::rand_core::OsRng::default(),
|
|
))
|
|
}
|
|
}
|
|
impl AsRef<[u8]> for WgKey {
|
|
fn as_ref(&self) -> &[u8] {
|
|
self.0.as_bytes()
|
|
}
|
|
}
|
|
impl TryFrom<Vec<u8>> for WgKey {
|
|
type Error = Error;
|
|
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
|
|
Ok(Self(
|
|
<[u8; 32]>::try_from(value)
|
|
.map_err(|_| Error::new(eyre!("invalid key length"), ErrorKind::Deserialization))?
|
|
.into(),
|
|
))
|
|
}
|
|
}
|
|
impl std::ops::Deref for WgKey {
|
|
type Target = StaticSecret;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
impl Base64<WgKey> {
|
|
pub fn verifying_key(&self) -> Base64<PublicKey> {
|
|
Base64((&*self.0).into())
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Deserialize, Serialize, HasModel)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[model = "Model<Self>"]
|
|
pub struct WgConfig {
|
|
pub name: InternedString,
|
|
pub key: Base64<WgKey>,
|
|
pub psk: Base64<[u8; 32]>,
|
|
}
|
|
impl WgConfig {
|
|
pub fn generate(name: InternedString) -> Self {
|
|
Self {
|
|
name,
|
|
key: Base64(WgKey::generate()),
|
|
psk: Base64(rand::random()),
|
|
}
|
|
}
|
|
pub fn server_peer_config<'a>(&'a self, addr: Ipv4Addr) -> ServerPeerConfig<'a> {
|
|
ServerPeerConfig {
|
|
client_config: self,
|
|
client_addr: addr,
|
|
}
|
|
}
|
|
pub fn client_config(
|
|
self,
|
|
addr: Ipv4Addr,
|
|
server_pubkey: Base64<PublicKey>,
|
|
server_addr: SocketAddr,
|
|
) -> ClientConfig {
|
|
ClientConfig {
|
|
client_config: self,
|
|
client_addr: addr,
|
|
server_pubkey,
|
|
server_addr,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct ServerPeerConfig<'a> {
|
|
client_config: &'a WgConfig,
|
|
client_addr: Ipv4Addr,
|
|
}
|
|
impl<'a> std::fmt::Display for ServerPeerConfig<'a> {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(
|
|
f,
|
|
include_str!("./server-peer.conf.template"),
|
|
pubkey = self.client_config.key.verifying_key().to_padded_string(),
|
|
psk = self.client_config.psk.to_padded_string(),
|
|
addr = self.client_addr,
|
|
)
|
|
}
|
|
}
|
|
|
|
fn deserialize_verifying_key<'de, D>(deserializer: D) -> Result<Base64<PublicKey>, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
Base64::<Vec<u8>>::deserialize(deserializer).and_then(|b| {
|
|
Ok(Base64(PublicKey::from(<[u8; 32]>::try_from(b.0).map_err(
|
|
|e: Vec<u8>| serde::de::Error::invalid_length(e.len(), &"a 32 byte base64 string"),
|
|
)?)))
|
|
})
|
|
}
|
|
|
|
#[derive(Clone, Serialize, Deserialize)]
|
|
pub struct ClientConfig {
|
|
client_config: WgConfig,
|
|
client_addr: Ipv4Addr,
|
|
#[serde(deserialize_with = "deserialize_verifying_key")]
|
|
server_pubkey: Base64<PublicKey>,
|
|
server_addr: SocketAddr,
|
|
}
|
|
impl std::fmt::Display for ClientConfig {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(
|
|
f,
|
|
include_str!("./client.conf.template"),
|
|
name = self.client_config.name,
|
|
privkey = self.client_config.key.to_padded_string(),
|
|
psk = self.client_config.psk.to_padded_string(),
|
|
addr = self.client_addr,
|
|
server_pubkey = self.server_pubkey.to_padded_string(),
|
|
server_addr = self.server_addr,
|
|
)
|
|
}
|
|
}
|
|
|
|
pub struct ServerConfig<'a>(&'a WgServer);
|
|
impl<'a> std::fmt::Display for ServerConfig<'a> {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
let Self(server) = *self;
|
|
write!(
|
|
f,
|
|
include_str!("./server.conf.template"),
|
|
subnets = server.subnets.0.keys().join(", "),
|
|
server_port = server.port,
|
|
server_privkey = server.key.to_padded_string(),
|
|
)?;
|
|
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients.0) {
|
|
write!(f, "{}", peer.server_peer_config(*addr))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|