forked from dojoengine/origami
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vector.cairo
127 lines (109 loc) · 3.47 KB
/
vector.cairo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/// A fixed-length vector whose components are held in a `Span`.
///
/// The type is `Copy`, so it can be passed by value freely.
#[derive(Drop, Copy)]
struct Vector<T> {
    /// Components of the vector, in index order.
    data: Span<T>,
}
/// Panic messages raised by vector operations.
mod errors {
    /// Raised when two vectors of different lengths are combined (dot, add, sub).
    const INVALID_SIZE: felt252 = 'Vector: invalid size';
    /// Raised when `get` is asked for a position past the end of the vector.
    const INVALID_INDEX: felt252 = 'Vector: index out of bounds';
}
/// Basic operations on `Vector<T>`.
trait VectorTrait<T> {
    /// Wraps `data` in a `Vector` (in `VectorImpl`: no copy, no validation).
    fn new(data: Span<T>) -> Vector<T>;
    /// Returns the element at `index`. In `VectorImpl` this panics with
    /// `errors::INVALID_INDEX` when `index` is past the end.
    /// NOTE(review): the `u8` index means positions above 255 cannot be requested, and
    /// `ref self` looks unnecessary since `VectorImpl::get` does not mutate the vector.
    fn get(ref self: Vector<T>, index: u8) -> T;
    /// Returns the number of elements in the vector.
    fn size(self: Vector<T>) -> u32;
    /// Returns the dot product of `self` and `vector`. In `VectorImpl` this panics with
    /// `errors::INVALID_SIZE` when the two lengths differ.
    fn dot(self: Vector<T>, vector: Vector<T>) -> T;
}
/// Default implementation of `VectorTrait` for any element type that supports
/// multiplication, in-place addition and a zero value.
impl VectorImpl<T, +Mul<T>, +AddEq<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,> of VectorTrait<T> {
    /// Builds a vector that views `data` directly.
    fn new(data: Span<T>) -> Vector<T> {
        Vector { data: data }
    }

    /// Returns a copy of the element at `index`.
    /// Panics with `errors::INVALID_INDEX` if `index` is past the end.
    fn get(ref self: Vector<T>, index: u8) -> T {
        let position: u32 = index.into();
        let boxed = self.data.get(position).expect(errors::INVALID_INDEX);
        *boxed.unbox()
    }

    /// Number of components held by the vector.
    fn size(self: Vector<T>) -> u32 {
        let components = self.data;
        components.len()
    }

    /// Computes the sum over i of `self[i] * vector[i]`.
    /// Panics with `errors::INVALID_SIZE` when the two vectors differ in length.
    fn dot(self: Vector<T>, vector: Vector<T>) -> T {
        // [Check] Dimensions are compatible
        let length = self.size();
        assert(length == vector.size(), errors::INVALID_SIZE);
        // [Compute] Accumulate pairwise products, position by position from the front
        let mut acc: T = Zeroable::zero();
        let mut i: u32 = 0;
        loop {
            if i == length {
                break acc;
            }
            acc += *self.data.at(i) * *vector.data.at(i);
            i += 1;
        }
    }
}
/// Component-wise addition: `(lhs + rhs)[i] == lhs[i] + rhs[i]`.
///
/// Panics with `errors::INVALID_SIZE` when the operands differ in length.
impl VectorAdd<
    T, +Mul<T>, +AddEq<T>, +Add<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,
> of Add<Vector<T>> {
    fn add(mut lhs: Vector<T>, mut rhs: Vector<T>) -> Vector<T> {
        // [Check] Dimensions are compatible
        assert(lhs.size() == rhs.size(), errors::INVALID_SIZE);
        // [Compute] Walk both spans in lockstep. Popping avoids `get` and its `u8` index:
        // a u8 counter overflows (and panics) on operands of 256 or more elements.
        let mut values = array![];
        loop {
            match lhs.data.pop_front() {
                Option::Some(x_value) => {
                    // Cannot fail: both spans had the same length at the check above.
                    let y_value = rhs.data.pop_front().unwrap();
                    values.append(*x_value + *y_value);
                },
                Option::None => { break; },
            };
        };
        VectorTrait::new(values.span())
    }
}
/// Component-wise subtraction: `(lhs - rhs)[i] == lhs[i] - rhs[i]`.
///
/// Panics with `errors::INVALID_SIZE` when the operands differ in length.
impl VectorSub<
    T, +Mul<T>, +AddEq<T>, +Sub<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,
> of Sub<Vector<T>> {
    fn sub(mut lhs: Vector<T>, mut rhs: Vector<T>) -> Vector<T> {
        // [Check] Dimensions are compatible
        assert(lhs.size() == rhs.size(), errors::INVALID_SIZE);
        // [Compute] Walk both spans in lockstep. Popping avoids `get` and its `u8` index:
        // a u8 counter overflows (and panics) on operands of 256 or more elements.
        let mut values = array![];
        loop {
            match lhs.data.pop_front() {
                Option::Some(x_value) => {
                    // Cannot fail: both spans had the same length at the check above.
                    let y_value = rhs.data.pop_front().unwrap();
                    values.append(*x_value - *y_value);
                },
                Option::None => { break; },
            };
        };
        VectorTrait::new(values.span())
    }
}
#[cfg(test)]
mod tests {
    // Core imports
    use debug::PrintTrait;
    // Local imports
    use super::{Vector, VectorTrait, VectorAdd, VectorSub};

    /// `Zeroable` for `i128`, provided locally for tests over signed elements.
    impl I128Zeroable of Zeroable<i128> {
        fn zero() -> i128 {
            0
        }
        fn is_zero(self: i128) -> bool {
            self == 0
        }
        fn is_non_zero(self: i128) -> bool {
            self != 0
        }
    }

    #[test]
    fn test_vector_get() {
        let mut vector: Vector = VectorTrait::new(array![1, 2, 3, 4].span());
        assert(vector.get(0) == 1, 'Vector: get failed');
        assert(vector.get(2) == 3, 'Vector: get failed');
    }

    #[test]
    #[should_panic(expected: ('Vector: index out of bounds',))]
    fn test_vector_get_out_of_bounds() {
        let mut vector: Vector = VectorTrait::new(array![1, 2, 3, 4].span());
        vector.get(4);
    }

    #[test]
    fn test_vector_size() {
        let vector: Vector = VectorTrait::new(array![1, 2, 3, 4].span());
        assert(vector.size() == 4, 'Vector: size failed');
    }

    #[test]
    fn test_vector_dot_product() {
        let vector1: Vector = VectorTrait::new(array![1, 2, 3].span());
        let vector2: Vector = VectorTrait::new(array![4, 5, 6].span());
        let result = vector1.dot(vector2);
        assert(result == 32, 'Vector: dot product failed'); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    #[should_panic(expected: ('Vector: invalid size',))]
    fn test_vector_dot_size_mismatch() {
        let vector1: Vector = VectorTrait::new(array![1, 2, 3].span());
        let vector2: Vector = VectorTrait::new(array![4, 5].span());
        vector1.dot(vector2);
    }

    #[test]
    fn test_vector_add() {
        let lhs: Vector = VectorTrait::new(array![1, 2, 3].span());
        let rhs: Vector = VectorTrait::new(array![4, 5, 6].span());
        let mut sum = lhs + rhs;
        assert(sum.size() == 3, 'Vector: add size failed');
        assert(sum.get(0) == 5, 'Vector: add failed');
        assert(sum.get(1) == 7, 'Vector: add failed');
        assert(sum.get(2) == 9, 'Vector: add failed');
    }

    #[test]
    #[should_panic(expected: ('Vector: invalid size',))]
    fn test_vector_add_size_mismatch() {
        let lhs: Vector = VectorTrait::new(array![1, 2, 3].span());
        let rhs: Vector = VectorTrait::new(array![4, 5].span());
        lhs + rhs;
    }

    #[test]
    fn test_vector_sub() {
        let lhs: Vector = VectorTrait::new(array![4, 5, 6].span());
        let rhs: Vector = VectorTrait::new(array![1, 2, 3].span());
        let mut diff = lhs - rhs;
        assert(diff.size() == 3, 'Vector: sub size failed');
        assert(diff.get(0) == 3, 'Vector: sub failed');
        assert(diff.get(1) == 3, 'Vector: sub failed');
        assert(diff.get(2) == 3, 'Vector: sub failed');
    }
}