ZonoOpt v2.0.1
Loading...
Searching...
No Matches
GenUtilities.hpp
Go to the documentation of this file.
1#ifndef ZONOOPT_GENUTILITIES_HPP_
2#define ZONOOPT_GENUTILITIES_HPP_
3
#include <cstddef>
#include <iostream>
#include <sstream>
#include <vector>
18
namespace ZonoOpt::detail
{
    // generated by Gemini
    /**
     * @brief Backtracking helper: appends to @p result every combination of size @p k that
     *        extends @p current_combination with elements taken from elements[start_index..].
     *
     * Elements are appended in increasing index order, so each emitted combination keeps the
     * relative order of @p elements and never uses the same index twice.
     *
     * @param elements            source elements to choose from
     * @param k                   target combination size
     * @param start_index         first index of @p elements that may be appended
     * @param current_combination partial combination built so far; every push_back is matched by
     *                            a pop_back, so it holds its original contents again on return
     * @param result              output; completed combinations are appended (it is not cleared)
     */
    template<typename T>
    void combinations_util(const std::vector<T>& elements, const std::size_t k, const std::size_t start_index,
                           std::vector<T>& current_combination,
                           std::vector<std::vector<T>>& result)
    {
        // Base case: the combination has reached size k.
        if (current_combination.size() == k)
        {
            result.push_back(current_combination);
            return;
        }

        // Elements still needed to complete the combination. This is loop-invariant because each
        // iteration's push_back is undone by its pop_back before the next iteration begins.
        const std::size_t needed = k - current_combination.size();

        for (std::size_t i = start_index; i < elements.size(); ++i)
        {
            // Prune: fewer than `needed` elements remain from index i onward, so neither this
            // index nor any later one can complete a combination.
            if (elements.size() - i < needed)
            {
                return;
            }

            current_combination.push_back(elements[i]);

            // Recurse from i + 1 so no index is reused and each combination is emitted once.
            combinations_util(elements, k, i + 1, current_combination, result);

            // Backtrack before trying the next candidate for this position.
            current_combination.pop_back();
        }
    }

    // generated by Gemini
    /**
     * @brief Returns all combinations of @p k elements drawn from @p input_set.
     *
     * Combinations are generated in lexicographic order of element indices, and each one lists
     * its elements in the same relative order as @p input_set. The input does not need to be
     * sorted or unique; elements are distinguished by position, so repeated values produce
     * repeated-looking combinations.
     *
     * @param input_set elements to choose from
     * @param k         number of elements per combination
     * @return all C(n, k) combinations; empty if k > input_set.size(); a single empty
     *         combination if k == 0
     */
    template<typename T>
    std::vector<std::vector<T>> get_combinations(const std::vector<T>& input_set, const std::size_t k)
    {
        if (k > input_set.size())
        {
            return {}; // Cannot choose k elements from a smaller set
        }

        std::vector<std::vector<T>> result;
        std::vector<T> current_combination;
        current_combination.reserve(k); // never grows beyond k elements

        // input_set is already a std::vector, so recurse over it directly rather than copying it.
        combinations_util(input_set, k, 0, current_combination, result);
        return result;
    }

    /**
     * @brief Prints the text accumulated in @p ss, then resets @p ss so it can be reused.
     *
     * When IS_PYTHON_ENV is defined the text is passed to py::print (the including translation
     * unit must provide `py` -- presumably pybind11; confirm at the include site). Otherwise it is
     * written to std::cout followed by std::endl, which also flushes.
     *
     * @param ss stream whose contents are printed; afterwards its buffer is empty and its state
     *           flags are cleared
     */
    inline void print_str(std::stringstream& ss)
    {
#ifdef IS_PYTHON_ENV
        py::print(ss.str());
#else
        std::cout << ss.str() << std::endl;
#endif
        ss.str("");  // discard the printed text
        ss.clear();  // str("") does not reset fail/bad/eof flags; without this a failed stream stays unusable
    }
}
84
85#endif