geosets_rs/sets/
zonotope.rs

1#![allow(unused)]
2use super::*;
3use crate::linalg_utils::{rank, sign, vector_all_close};
4use crate::qhull_wrapper::convex_hull_vertices;
5use good_lp::{Expression, Solution, SolverModel, constraint, default_solver, variable, variables};
6use itertools::Itertools;
7use ndarray::Shape;
8use ndarray_linalg::Determinant;
9use ndarray_rand::RandomExt;
10use ndarray_rand::rand_distr::{Exp1, Uniform};
11use qhull::Qh;
12use thiserror::Error;
13
14#[derive(Clone, Debug)]
15#[allow(non_snake_case)]
16pub struct Zonotope {
17    G: Array2<f64>,
18    c: Array1<f64>,
19}
20
21#[derive(Error, Debug)]
22pub enum ZonotopeError {
23    #[error("Dimensions of G {g_dim:?} and c {c_dim:?} do not match")]
24    DimensionMismatch { g_dim: (usize, usize), c_dim: usize },
25}
26
27#[allow(non_snake_case)]
28impl Zonotope {
29    pub fn new(G: Array2<f64>, c: Array1<f64>) -> Result<Zonotope, ZonotopeError> {
30        if G.dim().1 != c.dim() {
31            Err(ZonotopeError::DimensionMismatch {
32                g_dim: G.dim(),
33                c_dim: c.dim(),
34            })
35        } else {
36            Ok(Zonotope { G, c })
37        }
38    }
39
40    pub fn from_random(
41        dim: usize,
42        n_generators: usize,
43        zero_centered: bool,
44    ) -> Result<Zonotope, ZonotopeError> {
45        let mut G = Array2::random((n_generators, dim), Exp1);
46
47        let mut c = if zero_centered {
48            Array1::zeros(dim)
49        } else {
50            Array1::random(dim, Uniform::new(-0.2, 0.2))
51        };
52
53        let max_deviation = c.abs() + G.abs().sum_axis(Axis(0));
54        // get max of max_deviation as scalar
55        let scale = max_deviation.fold(0.0_f64, |a, &b| a.max(b));
56
57        if scale > 1. {
58            G /= scale;
59            c /= scale;
60        }
61
62        Ok(Zonotope { G, c })
63    }
64
65    pub fn n_generators(&self) -> usize {
66        self.G.nrows()
67    }
68
69    pub fn is_zero_centered(&self) -> bool {
70        self.c.iter().all(|&x| x.abs() < 1e-9)
71    }
72
73    pub fn zonotope_norm(&self, point: &Array1<f64>) -> Result<f64, SetOperationError> {
74        self._check_operand_dim(point.dim())?;
75
76        // if !self.is_zero_centered() {
77        //     return Err(SetOperationError::UnsupportedOperation {
78        //         message: "Zonotope must be zero-centered",
79        //     });
80        // }
81
82        let m = self.n_generators();
83        if self.degenerate() {
84            if vector_all_close(point, &self.c, 1e-9) {
85                return Ok(0.0);
86            } else {
87                return Ok(f64::INFINITY);
88            }
89        }
90
91        // Optimization problem
92
93        let mut vars = variables!();
94        let lambda = vars.add(variable().min(0));
95        let alpha: Vec<_> = (0..self.n_generators())
96            .map(|_| vars.add(variable()))
97            .collect();
98
99        let objective: Expression = lambda.into();
100        let mut problem = vars.minimise(objective).using(default_solver);
101
102        // G \alpha = x
103        for i in 0..self.dim() {
104            let g = &self.G.column(i);
105            // Iterate over dimensions, not generators
106            let expr: Expression = g
107                .iter()
108                .zip(&alpha)
109                .map(|(g_i, alpha_i)| *g_i * *alpha_i)
110                .sum();
111            problem = problem.with(expr.eq(point[i] - self.c[i]));
112        }
113
114        // \alpha \leq \lamdba, \alpha \geq -\lambda
115        for alpha_i in &alpha {
116            problem = problem.with(constraint!(*alpha_i <= lambda));
117            problem = problem.with(constraint!(*alpha_i >= -lambda));
118        }
119
120        match problem.solve() {
121            Ok(solution) => {
122                let lambda_val = solution.value(lambda);
123                Ok(lambda_val)
124            }
125            Err(_) => Ok(f64::INFINITY),
126        }
127    }
128}
129
130#[allow(non_snake_case)]
131impl GeoSet for Zonotope {
132    fn from_unit_box(dim: usize) -> Self {
133        let G = Array2::eye(dim);
134        let c = Array1::zeros(dim);
135        Zonotope::new(G, c).unwrap()
136    }
137
138    fn dim(&self) -> usize {
139        self.c.dim()
140    }
141
142    fn empty(&self) -> Result<bool, SetOperationError> {
143        Ok(false)
144    }
145
146    fn to_vertices(&self) -> Result<Array2<f64>, SetOperationError> {
147        let mut vertices = self.c.clone().into_shape_clone((1, self.dim())).unwrap();
148
149        for i in 0..self.n_generators() {
150            vertices = ndarray::concatenate(
151                Axis(0),
152                &[
153                    (&vertices + &self.G.row(i)).view(),
154                    (&vertices - &self.G.row(i)).view(),
155                ],
156            )
157            .unwrap();
158        }
159
160        // Compute convex hull using qhull -> automatically propagates error
161        let hull_vertices = convex_hull_vertices(&vertices)?;
162        Ok(hull_vertices)
163    }
164
165    fn center(&self) -> Result<Array1<f64>, SetOperationError> {
166        Ok(self.c.clone())
167    }
168
169    fn support_function(
170        &self,
171        direction: Array1<f64>,
172    ) -> Result<(Array1<f64>, f64), SetOperationError> {
173        self._check_operand_dim(direction.dim())?;
174
175        let projection = self.G.dot(&direction);
176        // signum is not correct here!
177        let projection_sign = sign(&projection);
178
179        let support_value = direction.dot(&self.c) + projection.abs().sum();
180        let support_vector = &self.c + projection_sign.dot(&self.G);
181
182        Ok((support_vector, support_value))
183    }
184
185    fn volume(&self) -> Result<f64, SetOperationError> {
186        if self.degenerate() {
187            return Ok(0.0);
188        }
189
190        let all_combinations = (0..self.n_generators()).combinations(self.dim());
191
192        let mut vol = 0.0;
193        for comb in all_combinations {
194            let submatrix = self.G.select(Axis(0), &comb);
195            vol += submatrix.det().unwrap().abs();
196        }
197        Ok(2.0_f64.powf(self.dim() as f64) * vol)
198    }
199
200    fn minkowski_sum_(&mut self, other: &Self) -> Result<(), SetOperationError> {
201        self._check_operand_dim(other.dim())?;
202        self.G = concatenate![Axis(0), self.G.clone(), other.G.clone()];
203        // self.G = ndarray::concatenate(Axis(1), &[self.G.view(), other.G.view()]);
204        self.c = &self.c + &other.c;
205        Ok(())
206    }
207
208    fn matmul_(&mut self, mat: &Array2<f64>) -> Result<(), SetOperationError> {
209        self._check_operand_dim(mat.dim().0)?;
210        self.c = mat.dot(&self.c);
211        self.G = self.G.dot(&mat.t());
212        Ok(())
213    }
214
215    fn translate_(&mut self, vector: &Array1<f64>) -> Result<(), SetOperationError> {
216        self._check_operand_dim(vector.dim())?;
217        self.c = &self.c + vector;
218        Ok(())
219    }
220
221    fn degenerate(&self) -> bool {
222        self.n_generators() == 0 || rank(&self.G).unwrap() < self.dim()
223    }
224
225    fn contains_point(&self, point: &Array1<f64>) -> Result<bool, SetOperationError> {
226        Ok(self.zonotope_norm(point)? <= 1.0 + 1e-9)
227    }
228}
229
#[cfg(test)]
#[allow(non_snake_case)] // locals such as `G` / `expected_G` follow the math notation
mod tests {
    use super::*;

    #[test]
    fn test_zonotope_new() {
        // Five generators in R^2 with a 2-D center: dimensions agree.
        let _ = Zonotope::new(Array::ones((5, 2)), Array::zeros(2)).unwrap();

        // Generators in R^3 but a center in R^2: construction must be rejected.
        let zono = Zonotope::new(Array::eye(3), Array::zeros(2));
        assert!(matches!(
            zono,
            Err(ZonotopeError::DimensionMismatch {
                g_dim: (3, 3),
                c_dim: 2
            })
        ));
    }

    #[test]
    fn test_matmul_rotation() {
        let G = array![[1.0, 0.0], [0.0, 1.0]]; // Two generators: [1,0] and [0,1]
        let c = array![2.0, 3.0]; // Center at (2,3)
        let mut zono = Zonotope::new(G, c).unwrap();

        // Apply a 90-degree rotation matrix
        let rotation_90 = array![[0.0, -1.0], [1.0, 0.0]];

        zono.matmul_(&rotation_90).unwrap();

        // After rotation:
        // - Center (2,3) should become (-3,2)
        // - Generator [1,0] should become [0,1]
        // - Generator [0,1] should become [-1,0]
        let expected_c = array![-3.0, 2.0];
        let expected_G = array![[0.0, 1.0], [-1.0, 0.0]];

        assert!(zono.c.abs_diff_eq(&expected_c, 1e-10));
        assert!(zono.G.abs_diff_eq(&expected_G, 1e-10));
    }

    #[test]
    fn test_matmul_scaling() {
        let G = array![[2.0, 1.0], [1.0, 3.0]];
        let c = array![1.0, 2.0];
        let mut zono = Zonotope::new(G, c).unwrap();

        // Scale by 2 in x, 3 in y
        let scale = array![[2.0, 0.0], [0.0, 3.0]];

        zono.matmul_(&scale).unwrap();

        // Center should be scaled
        let expected_c = array![2.0, 6.0];
        // Generators should be transformed: [2,1] -> [4,3], [1,3] -> [2,9]
        let expected_G = array![[4.0, 3.0], [2.0, 9.0]];

        assert!(zono.c.abs_diff_eq(&expected_c, 1e-10));
        assert!(zono.G.abs_diff_eq(&expected_G, 1e-10));
    }

    #[test]
    fn test_matmul_rejects_wrong_dimension() {
        let mut zono = Zonotope::from_unit_box(2);
        assert!(zono.matmul_(&Array2::<f64>::eye(3)).is_err());
    }

    #[test]
    fn test_unit_box_volume() {
        // [-1, 1]^2 has area 4.
        let vol = Zonotope::from_unit_box(2).volume().unwrap();
        assert!((vol - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_unit_box_support_function() {
        let (support_vector, support_value) = Zonotope::from_unit_box(2)
            .support_function(array![1.0, 2.0])
            .unwrap();
        assert!((support_value - 3.0).abs() < 1e-12);
        assert!(support_vector.abs_diff_eq(&array![1.0, 1.0], 1e-12));
    }

    #[test]
    fn test_unit_box_contains_point() {
        let zono = Zonotope::from_unit_box(2);
        assert!(zono.contains_point(&array![0.5, -0.5]).unwrap());
        assert!(!zono.contains_point(&array![1.5, 0.0]).unwrap());
    }

    #[test]
    fn test_minkowski_sum_stacks_generators_and_adds_centers() {
        let mut a = Zonotope::new(array![[1.0, 0.0]], array![1.0, 2.0]).unwrap();
        let b = Zonotope::new(array![[0.0, 1.0], [1.0, 1.0]], array![-1.0, 0.5]).unwrap();

        a.minkowski_sum_(&b).unwrap();

        assert_eq!(a.n_generators(), 3);
        assert!(a.c.abs_diff_eq(&array![0.0, 2.5], 1e-12));
        assert!(a
            .G
            .abs_diff_eq(&array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], 1e-12));
    }

    #[test]
    fn test_translate_moves_center_only() {
        let mut zono = Zonotope::from_unit_box(2);
        zono.translate_(&array![0.5, -1.0]).unwrap();

        assert!(zono
            .center()
            .unwrap()
            .abs_diff_eq(&array![0.5, -1.0], 1e-12));
        assert!(zono.G.abs_diff_eq(&Array2::<f64>::eye(2), 1e-12));
    }
}