diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index 22b284623..39cd10a38 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -160,6 +160,16 @@ pub async fn add_subnet( .db .mutate(|db| { let map = db.as_wg_mut().as_subnets_mut(); + if let Some(s) = map + .keys()? + .into_iter() + .find(|s| s != &subnet && (s.contains(&subnet) || subnet.contains(s))) + { + return Err(Error::new( + eyre!("{subnet} overlaps with existing subnet {s}"), + ErrorKind::InvalidRequest, + )); + } map.upsert(&subnet, || { Ok(WgSubnetConfig::new(InternedString::default())) })? diff --git a/sdk/base/lib/util/ip.ts b/sdk/base/lib/util/ip.ts index 026bdd3cc..1f2c25067 100644 --- a/sdk/base/lib/util/ip.ts +++ b/sdk/base/lib/util/ip.ts @@ -1,8 +1,11 @@ export class IpAddress { + private renderedOctets: number[] protected constructor( - readonly octets: number[], - readonly address: string, - ) {} + public octets: number[], + private renderedAddress: string, + ) { + this.renderedOctets = [...octets] + } static parse(address: string): IpAddress { let octets if (address.includes(":")) { @@ -120,14 +123,48 @@ export class IpAddress { } return 0 } + get address(): string { + if ( + this.renderedOctets.length === this.octets.length && + this.renderedOctets.every((o, idx) => o === this.octets[idx]) + ) { + // already rendered + } else if (this.octets.length === 4) { + this.renderedAddress = this.octets.join(".") + this.renderedOctets = [...this.octets] + } else if (this.octets.length === 16) { + const contigZeros = this.octets.reduce( + (acc, x, idx) => { + if (x === 0) { + acc.current++ + } else { + acc.current = 0 + } + if (acc.current > acc.end - acc.start) { + acc.end = idx + 1 + acc.start = acc.end - acc.current + } + return acc + }, + { start: 0, end: 0, current: 0 }, + ) + if (contigZeros.end - contigZeros.start >= 2) { + return `${this.octets.slice(0, contigZeros.start).join(":")}::${this.octets.slice(contigZeros.end).join(":")}` + } + this.renderedAddress = this.octets.join(":") + this.renderedOctets = [...this.octets] + } else { + console.warn("invalid octet length for IpAddress", this.octets) + } + return this.renderedAddress + } } export class IpNet extends IpAddress { private constructor( octets: number[], - readonly prefix: number, + public prefix: number, address: string, - readonly ipnet: string, ) { super(octets, address) } @@ -135,7 +172,7 @@ export class IpNet extends IpAddress { if (prefix > ip.octets.length * 8) { throw new Error("invalid prefix") } - return new IpNet(ip.octets, prefix, ip.address, `${ip.address}/${prefix}`) + return new IpNet(ip.octets, prefix, ip.address) } static parse(ipnet: string): IpNet { const [address, prefixStr] = ipnet.split("/", 2) @@ -143,8 +180,9 @@ export class IpNet extends IpAddress { const prefix = Number(prefixStr) return IpNet.fromIpPrefix(ip, prefix) } - contains(address: string | IpAddress): boolean { + contains(address: string | IpAddress | IpNet): boolean { if (typeof address === "string") address = IpAddress.parse(address) + if (address instanceof IpNet && address.prefix < this.prefix) return false if (this.octets.length !== address.octets.length) return false let prefix = this.prefix let idx = 0 @@ -191,6 +229,9 @@ export class IpNet extends IpAddress { return IpAddress.fromOctets(octets) } + get ipnet() { + return `${this.address}/${this.prefix}` + } } export const PRIVATE_IPV4_RANGES = [ diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts b/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts index e395a0c8a..202832937 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/subnets/index.ts @@ -133,18 +133,21 @@ export default class Subnets { } private getNext(): string { - const current = this.subnets().map(s => - utils.IpNet.parse(s.range).octets.slice(0, 2).join('.'), - ) + const current = this.subnets().map(s => utils.IpNet.parse(s.range)) + const suggestion = utils.IpNet.parse('10.59.0.1/24') for (let i = 0; i < 256; i++) { - const first3 = `10.59.${Math.floor(Math.random() * 256)}` - if (!current.includes(first3)) { - return `${first3}.0/24` + suggestion.octets[2] = Math.floor(Math.random() * 256) + if ( + !current.some( + s => s.contains(suggestion), // inverse check unnecessary since we don't allow subnets smaller than /24 + ) + ) { + return suggestion.ipnet } } - // No recommendation if /24 subnets are used from 10.59 + // No recommendation if can't find a /24 from 10.59 in 256 random tries return '' } }