geosets_rs/sets/
mod.rs

1use ndarray::concatenate;
2use ndarray::prelude::*;
3use plotly::Layout;
4use plotly::Trace;
5use plotly::common::Mode;
6use plotly::{Plot, Scatter};
7
8use self::errors::SetOperationError;
9
10pub mod errors;
11pub mod hpolytope;
12pub mod interval;
13pub mod vpolytope;
14pub mod zonotope;
15
16pub trait GeoSet: Sized + Clone {
17    fn dim(&self) -> usize;
18    fn empty(&self) -> Result<bool, SetOperationError>;
19    fn degenerate(&self) -> bool;
20
21    // Static function
22    fn from_unit_box(dim: usize) -> Self;
23
24    fn to_vertices(&self) -> Result<Array2<f64>, SetOperationError>;
25    fn center(&self) -> Result<Array1<f64>, SetOperationError>;
26    fn support_function(
27        &self,
28        direction: Array1<f64>,
29    ) -> Result<(Array1<f64>, f64), SetOperationError>;
30    fn volume(&self) -> Result<f64, SetOperationError>;
31    fn contains_point(&self, point: &Array1<f64>) -> Result<bool, SetOperationError>;
32
33    // Operations
34    fn minkowski_sum_(&mut self, other: &Self) -> Result<(), SetOperationError>;
35    fn matmul_(&mut self, mat: &Array2<f64>) -> Result<(), SetOperationError>;
36    fn translate_(&mut self, vector: &Array1<f64>) -> Result<(), SetOperationError>;
37
38    fn minkowski_sum(&self, other: &Self) -> Result<Self, SetOperationError> {
39        let mut copy = self.clone();
40        copy.minkowski_sum_(other)?;
41        Ok(copy)
42    }
43    fn matmul(&self, mat: &Array2<f64>) -> Result<Self, SetOperationError> {
44        let mut copy = self.clone();
45        copy.matmul_(mat)?;
46        Ok(copy)
47    }
48    fn translate(&self, vector: &Array1<f64>) -> Result<Self, SetOperationError> {
49        let mut copy = self.clone();
50        copy.translate_(vector)?;
51        Ok(copy)
52    }
53
54    // Generic implementations
55    fn create_trace(
56        &self,
57        dim: (usize, usize),
58        name: Option<&str>,
59    ) -> Result<Box<dyn Trace>, SetOperationError> {
60        use crate::geometric_operations::order_vertices_clockwise;
61        let full_vertices = self.to_vertices()?;
62        let col_x = full_vertices.column(dim.0);
63        let col_y = full_vertices.column(dim.1);
64        let vertices_2d = ndarray::stack(Axis(1), &[col_x, col_y]).unwrap();
65
66        let vertices = order_vertices_clockwise(vertices_2d).unwrap();
67
68        let closed_vertices = ndarray::concatenate(
69            Axis(0),
70            &[vertices.view(), vertices.row(0).view().insert_axis(Axis(0))],
71        )
72        .unwrap();
73
74        let x = closed_vertices.column(dim.0).to_vec();
75        let y = closed_vertices.column(dim.1).to_vec();
76
77        let mut trace = Scatter::new(x, y)
78            .mode(Mode::LinesMarkers)
79            .fill(plotly::common::Fill::ToSelf)
80            .opacity(0.8);
81
82        if let Some(trace_name) = name {
83            trace = trace.name(trace_name);
84        }
85
86        Ok(trace)
87    }
88
89    fn plot(
90        &self,
91        dim: (usize, usize),
92        equal_axis: bool,
93        show: bool,
94    ) -> Result<Plot, SetOperationError> {
95        let mut plot = Plot::new();
96        let trace = self.create_trace(dim, None).unwrap();
97        plot.add_trace(trace);
98
99        if equal_axis {
100            let layout = Layout::new()
101                .x_axis(plotly::layout::Axis::new())
102                .y_axis(plotly::layout::Axis::new().scale_anchor("x"));
103            plot.set_layout(layout);
104        }
105
106        if show {
107            plot.show();
108        }
109
110        Ok(plot)
111    }
112
113    // Utils
114    fn _check_operand_dim(&self, dim: usize) -> Result<(), SetOperationError> {
115        if dim != self.dim() {
116            return Err(SetOperationError::DimensionMismatch {
117                expected: self.dim(),
118                got: dim,
119            });
120        }
121        Ok(())
122    }
123}