diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index af8a48831..e2e60ffe7 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -141,17 +141,19 @@ pub struct AddSubnetParams { pub async fn add_subnet( ctx: TunnelContext, AddSubnetParams { name }: AddSubnetParams, - SubnetParams { subnet }: SubnetParams, + SubnetParams { mut subnet }: SubnetParams, ) -> Result<(), Error> { - if subnet.addr().octets()[3] == 0 - || subnet.addr().octets()[3] == 255 - || subnet.prefix_len() > 24 - { + if subnet.prefix_len() > 24 { return Err(Error::new( eyre!("invalid subnet"), ErrorKind::InvalidRequest, )); } + let addr = subnet + .hosts() + .next() + .ok_or_else(|| Error::new(eyre!("invalid subnet"), ErrorKind::InvalidRequest))?; + subnet = Ipv4Net::new_assert(addr, subnet.prefix_len()); let server = ctx .db .mutate(|db| { @@ -196,7 +198,6 @@ pub async fn add_device( ctx: TunnelContext, AddDeviceParams { subnet, name, ip }: AddDeviceParams, ) -> Result<(), Error> { - let config = WgConfig::generate(name); let server = ctx .db .mutate(|db| { @@ -223,18 +224,21 @@ pub async fn add_device( if ip.octets()[3] == 0 || ip.octets()[3] == 255 { return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest)); } + if ip == subnet.addr() { + return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest)); + } if !subnet.contains(&ip) { return Err(Error::new( eyre!("ip not in subnet"), ErrorKind::InvalidRequest, )); } - clients.insert(ip, config).map_or(Ok(()), |_| { - Err(Error::new( - eyre!("ip already in use"), - ErrorKind::InvalidRequest, - )) - }) + let client = clients + .entry(ip) + .or_insert_with(|| WgConfig::generate(name.clone())); + client.name = name; + + Ok(()) })?; db.as_wg().de() }) @@ -247,12 +251,12 @@ pub async fn add_device( #[serde(rename_all = "camelCase")] pub struct RemoveDeviceParams { subnet: Ipv4Net, - device: Ipv4Addr, + ip: Ipv4Addr, } pub async fn remove_device( ctx: TunnelContext, - RemoveDeviceParams { subnet, device }: RemoveDeviceParams, + RemoveDeviceParams { subnet, ip }: RemoveDeviceParams, ) -> Result<(), Error> { let server = ctx .db @@ -262,8 +266,8 @@ pub async fn remove_device( .as_idx_mut(&subnet) .or_not_found(&subnet)? .as_clients_mut() - .remove(&device)? - .or_not_found(&device)?; + .remove(&ip)? + .or_not_found(&ip)?; db.as_wg().de() }) .await diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index f646c343b..421e8e2e6 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -40,12 +40,6 @@ use crate::util::collections::OrdMapIterMut; use crate::util::io::read_file_to_string; use crate::util::sync::{SyncMutex, Watch}; -#[cfg(all(feature = "tunnel", not(feature = "test")))] -const EMBEDDED_TUNNEL_UI_ROOT: Dir<'_> = - include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static"); -#[cfg(not(all(feature = "tunnel", not(feature = "test"))))] -const EMBEDDED_TUNNEL_UI_ROOT: Dir<'_> = Dir::new("", &[]); - #[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)] #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] diff --git a/sdk/base/lib/osBindings/SignerInfo.ts b/sdk/base/lib/osBindings/SignerInfo.ts index 76cbdafce..7e7aa2588 100644 --- a/sdk/base/lib/osBindings/SignerInfo.ts +++ b/sdk/base/lib/osBindings/SignerInfo.ts @@ -1,3 +1,9 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { AnyVerifyingKey } from "./AnyVerifyingKey" +import type { ContactInfo } from "./ContactInfo" -export type SignerInfo = { name: string } +export type SignerInfo = { + name: string + contact: Array + keys: Array +} diff --git a/sdk/base/lib/util/ip.ts b/sdk/base/lib/util/ip.ts index c7a26102e..79174a24d 100644 --- a/sdk/base/lib/util/ip.ts +++ b/sdk/base/lib/util/ip.ts @@ -1,15 +1,19 @@ export class IpAddress { - readonly octets: number[] - constructor(readonly address: string) { + protected constructor( + readonly octets: number[], + readonly address: string, + ) {} + static parse(address: string): IpAddress { + let octets if (address.includes(":")) { - this.octets = new Array(16).fill(0) + octets = new Array(16).fill(0) const segs = address.split(":") let idx = 0 let octIdx = 0 while (segs[idx]) { const num = parseInt(segs[idx], 16) - this.octets[octIdx++] = num >> 8 - this.octets[octIdx++] = num & 255 + octets[octIdx++] = num >> 8 + octets[octIdx++] = num & 255 idx += 1 } const lastSegIdx = segs.length - 1 @@ -18,21 +22,46 @@ export class IpAddress { octIdx = 15 while (segs[idx]) { const num = parseInt(segs[idx], 16) - this.octets[octIdx--] = num & 255 - this.octets[octIdx--] = num >> 8 + octets[octIdx--] = num & 255 + octets[octIdx--] = num >> 8 idx -= 1 } } } else { - this.octets = address.split(".").map(Number) - if (this.octets.length !== 4) throw new Error("invalid ipv4 address") + octets = address.split(".").map(Number) + if (octets.length !== 4) throw new Error("invalid ipv4 address") } - if (this.octets.some((o) => o >= 256)) { + if (octets.some((o) => o > 255)) { throw new Error("invalid ip address") } + return new IpAddress(octets, address) } - static parse(address: string): IpAddress { - return new IpAddress(address) + static fromOctets(octets: number[]) { + if (octets.length == 4) { + if (octets.some((o) => o > 255)) { + throw new Error("invalid ip address") + } + return new IpAddress(octets, octets.join(".")) + } else if (octets.length == 16) { + if (octets.some((o) => o > 255)) { + throw new Error("invalid ip address") + } + let pre = octets.slice(0, 8) + while (pre[pre.length - 1] == 0) { + pre.pop() + } + let post = octets.slice(8) + while (post[0] == 0) { + post.unshift() + } + if (pre.length + post.length == 16) { + return new IpAddress(octets, octets.join(":")) + } else { + return new IpAddress(octets, pre.join(":") + "::" + post.join(":")) + } + } else { + throw new Error("invalid ip address") + } } isIpv4(): boolean { return this.octets.length === 4 @@ -43,20 +72,79 @@ export class IpAddress { isPublic(): boolean { return this.isIpv4() && !PRIVATE_IPV4_RANGES.some((r) => r.contains(this)) } + add(n: number): IpAddress { + let octets = [...this.octets] + n = Math.floor(n) + for (let i = octets.length - 1; i >= 0; i--) { + octets[i] += n + if (octets[i] > 255) { + n = octets[i] >> 8 + octets[i] &= 255 + } else { + break + } + } + if (octets[0] > 255) { + throw new Error("overflow incrementing ip") + } + return IpAddress.fromOctets(octets) + } + sub(n: number): IpAddress { + let octets = [...this.octets] + n = Math.floor(n) + for (let i = octets.length - 1; i >= 0; i--) { + octets[i] -= n + if (octets[i] < 0) { + n = Math.ceil(Math.abs(octets[i]) / 256) + octets[i] = ((octets[i] % 256) + 256) % 256 + } else { + break + } + } + if (octets[0] < 0) { + throw new Error("underflow decrementing ip") + } + return IpAddress.fromOctets(octets) + } + cmp(other: string | IpAddress): -1 | 0 | 1 { + if (typeof other === "string") other = IpAddress.parse(other) + const len = Math.max(this.octets.length, other.octets.length) + for (let i = 0; i < len; i++) { + const left = this.octets[i] || 0 + const right = other.octets[i] || 0 + if (left > right) { + return 1 + } else if (left < right) { + return -1 + } + } + return 0 + } } export class IpNet extends IpAddress { - readonly prefix - constructor(readonly ipnet: string) { - const [address, prefixStr] = ipnet.split("/", 2) - super(address) - this.prefix = Number(prefixStr) + private constructor( + octets: number[], + readonly prefix: number, + address: string, + readonly ipnet: string, + ) { + super(octets, address) + } + static fromIpPrefix(ip: IpAddress, prefix: number): IpNet { + if (prefix > ip.octets.length * 8) { + throw new Error("invalid prefix") + } + return new IpNet(ip.octets, prefix, ip.address, `${ip.address}/${prefix}`) } static parse(ipnet: string): IpNet { - return new IpNet(ipnet) + const [address, prefixStr] = ipnet.split("/", 2) + const ip = IpAddress.parse(address) + const prefix = Number(prefixStr) + return IpNet.fromIpPrefix(ip, prefix) } contains(address: string | IpAddress): boolean { - if (typeof address === "string") address = new IpAddress(address) + if (typeof address === "string") address = IpAddress.parse(address) if (this.octets.length !== address.octets.length) return false let prefix = this.prefix let idx = 0 @@ -68,20 +156,52 @@ export class IpNet extends IpAddress { prefix -= 8 } if (prefix === 0 || idx >= this.octets.length) return true - const mask = 255 << prefix + const mask = 255 ^ (255 >> prefix) return (this.octets[idx] & mask) === (address.octets[idx] & mask) } + zero(): IpAddress { + let octets: number[] = [] + let prefix = this.prefix + for (let idx = 0; idx < this.octets.length; idx++) { + if (prefix >= 8) { + octets[idx] = this.octets[idx] + prefix -= 8 + } else { + const mask = 255 ^ (255 >> prefix) + octets[idx] = this.octets[idx] & mask + prefix = 0 + } + } + + return IpAddress.fromOctets(octets) + } + last(): IpAddress { + let octets: number[] = [] + let prefix = this.prefix + for (let idx = 0; idx < this.octets.length; idx++) { + if (prefix >= 8) { + octets[idx] = this.octets[idx] + prefix -= 8 + } else { + const mask = 255 >> prefix + octets[idx] = this.octets[idx] | mask + prefix = 0 + } + } + + return IpAddress.fromOctets(octets) + } } export const PRIVATE_IPV4_RANGES = [ - new IpNet("127.0.0.0/8"), - new IpNet("10.0.0.0/8"), - new IpNet("172.16.0.0/12"), - new IpNet("192.168.0.0/16"), + IpNet.parse("127.0.0.0/8"), + IpNet.parse("10.0.0.0/8"), + IpNet.parse("172.16.0.0/12"), + IpNet.parse("192.168.0.0/16"), ] -export const IPV4_LOOPBACK = new IpNet("127.0.0.0/8") -export const IPV6_LOOPBACK = new IpNet("::1/128") -export const IPV6_LINK_LOCAL = new IpNet("fe80::/10") +export const IPV4_LOOPBACK = IpNet.parse("127.0.0.0/8") +export const IPV6_LOOPBACK = IpNet.parse("::1/128") +export const IPV6_LINK_LOCAL = IpNet.parse("fe80::/10") -export const CGNAT = new IpNet("100.64.0.0/10") +export const CGNAT = IpNet.parse("100.64.0.0/10") diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/devices/add.ts b/web/projects/start-tunnel/src/app/routes/home/routes/devices/add.ts index 822829221..fc44325e4 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/devices/add.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/devices/add.ts @@ -34,7 +34,13 @@ import { TuiForm } from '@taiga-ui/layout' import { injectContext, PolymorpheusComponent } from '@taiga-ui/polymorpheus' import { ApiService } from 'src/app/services/api/api.service' -import { getIp, DeviceData, MappedSubnet, subnetValidator } from './utils' +import { + getIp, + DeviceData, + MappedSubnet, + subnetValidator, + ipInSubnetValidator, +} from './utils' @Component({ template: ` @@ -132,6 +138,11 @@ export class DevicesAdd { range ? `${name} (${range})` : '' protected onSubnet(subnet: MappedSubnet) { + this.form.controls.ip.clearValidators() + this.form.controls.ip.addValidators([ + Validators.required, + ipInSubnetValidator(subnet.range), + ]) const ip = getIp(subnet) if (ip) { diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/devices/utils.ts b/web/projects/start-tunnel/src/app/routes/home/routes/devices/utils.ts index 959d99145..227d9b13a 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/devices/utils.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/devices/utils.ts @@ -30,25 +30,31 @@ export function subnetValidator({ value }: AbstractControl) { : { noHosts: 'No hosts available' } } -export function getIp({ clients, range }: MappedSubnet) { - const { prefix, octets } = new IpNet(range) - const used = Object.keys(clients).map(ip => - new utils.IpAddress(ip).octets.at(3), - ) +export const ipInSubnetValidator = (subnet: string | null = null) => { + const ipnet = subnet && utils.IpNet.parse(subnet) + return ({ value }: AbstractControl) => { + let ip: utils.IpAddress + try { + ip = utils.IpAddress.parse(value) + } catch (e) { + return { invalidIp: 'Not a valid IP Address' } + } + if (!ipnet) return null + return ipnet.contains(ip) + ? null + : { notInSubnet: `Address is not part of ${subnet}` } + } +} - for (let i = 2; i < totalHosts(prefix); i++) { - if (!used.includes(i)) { - return [...octets.slice(0, 3), i].join('.') +export function getIp({ clients, range }: MappedSubnet) { + const net = IpNet.parse(range) + const last = net.last() + + for (let ip = net.add(1); ip.cmp(last) === -1; ip.add(1)) { + if (!clients[ip.address]) { + return ip.address } } return '' } - -function totalHosts(prefix: number) { - // Handle special cases per RFC 3021 - if (prefix === 31) return 4 // point-to-point, 2 usable addresses - if (prefix === 32) return 3 // single host, 1 usable address - - return Math.pow(2, 32 - prefix) -} diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts b/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts index 0564116d4..f28357df4 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts @@ -73,7 +73,7 @@ export default class PortForwards { map(g => Object.values(g) .flatMap( - val => val.ipInfo?.subnets.map(s => new utils.IpNet(s)) || [], + val => val.ipInfo?.subnets.map(s => utils.IpNet.parse(s)) || [], ) .filter(s => s.isIpv4() && s.isPublic()) .map(s => s.address), diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts b/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts index 6663fedaa..2499a7c53 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts @@ -131,12 +131,13 @@ export default class Subnets { } private getNext(): string { - const used = this.subnets().map(s => new utils.IpNet(s.range).octets.at(2)) - - for (let i = 0; i < 256; i++) { - if (!used.includes(i)) { - return `10.59.${i}.0/24` - } + const last = + this.subnets() + .map(s => utils.IpNet.parse(s.range)) + .sort((a, b) => -1 * a.cmp(b))[0] ?? utils.IpNet.parse('10.58.255.0/24') + const next = utils.IpNet.fromIpPrefix(last.last().add(2), 24) + if (!next.isPublic()) { + return next.ipnet } // No recommendation if /24 subnets are used