Skip to content

Commit 40f83a4

Browse files
committed
Operator overloading for +
1 parent d46fed4 commit 40f83a4

File tree

12 files changed

+102
-66
lines changed

12 files changed

+102
-66
lines changed

include/Expr.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ struct Expr: Checkable, std::enable_shared_from_this<Expr> {
7070
}
7171
};
7272

73-
using Argument = std::variant<ExprPtr, VregPtr>;
74-
void compileCall(VregPtr, Function &, ScopePtr, FunctionPtr, std::initializer_list<Argument>, const ASTLocation &);
73+
using Argument = std::variant<Expr *, VregPtr>;
74+
void compileCall(VregPtr, Function &, ScopePtr, FunctionPtr, std::initializer_list<Argument>, const ASTLocation &,
75+
size_t = 1);
7576

7677
std::string stringify(const Expr *);
7778

include/Function.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class Function: public Makeable<Function> {
8787

8888
std::shared_ptr<Scope> newScope(int *id_out = nullptr);
8989

90-
VregPtr precolored(int reg);
90+
VregPtr precolored(int reg, bool bypass = false);
9191

9292
size_t addToStack(VariablePtr);
9393

@@ -183,6 +183,9 @@ class Function: public Makeable<Function> {
183183
bool isMatch(TypePtr return_type, const std::vector<TypePtr> &argument_types, const std::string &struct_name)
184184
const;
185185

186+
bool isMatch(TypePtr return_type, const std::vector<Type *> &argument_types, const std::string &struct_name)
187+
const;
188+
186189
TypePtr & getArgumentType(size_t) const;
187190

188191
Function & setStatic(bool);

include/Program.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct Program {
4444
void compile();
4545
size_t getStringID(const std::string &);
4646

47-
FunctionPtr getOperator(const std::vector<TypePtr> &, int, const ASTLocation & = {}) const;
47+
FunctionPtr getOperator(const std::vector<Type *> &, int, const ASTLocation & = {}) const;
4848
};
4949

5050
Program compileRoot(const ASTNode &, const std::string &filename);

include/Variable.h

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,31 @@ struct Type;
1515
struct WhyInstruction;
1616

1717
struct VirtualRegister: Checkable, std::enable_shared_from_this<VirtualRegister> {
18-
Function *function = nullptr;
19-
int id;
20-
int reg = -1;
21-
bool precolored = false;
22-
std::shared_ptr<Type> type;
18+
private:
19+
int reg = -1;
2320

24-
VirtualRegister(Function &, std::shared_ptr<Type> = nullptr);
25-
VirtualRegister(int id_, std::shared_ptr<Type> = nullptr);
26-
std::shared_ptr<VirtualRegister> init();
21+
public:
22+
Function *function = nullptr;
23+
int id;
24+
bool precolored = false;
25+
std::shared_ptr<Type> type;
2726

28-
virtual ~VirtualRegister() {}
27+
VirtualRegister(Function &, std::shared_ptr<Type> = nullptr);
28+
VirtualRegister(int id_, std::shared_ptr<Type> = nullptr);
29+
std::shared_ptr<VirtualRegister> init();
2930

30-
std::string regOrID(bool colored = false) const;
31-
bool special() const;
31+
virtual ~VirtualRegister() {}
3232

33-
WeakSet<BasicBlock> readingBlocks, writingBlocks;
34-
WeakSet<WhyInstruction> readers, writers;
33+
std::string regOrID(bool colored = false) const;
34+
bool special() const;
3535

36-
size_t getSize() const;
37-
virtual operator std::string() const { return regOrID(true); }
36+
WeakSet<BasicBlock> readingBlocks, writingBlocks;
37+
WeakSet<WhyInstruction> readers, writers;
38+
39+
size_t getSize() const;
40+
virtual operator std::string() const { return regOrID(true); }
41+
VirtualRegister & setReg(int, bool bypass = false);
42+
int getReg() const { return reg; }
3843
};
3944

4045
struct Variable: VirtualRegister, Makeable<Variable> {

src/BasicBlock.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ std::set<VregPtr> BasicBlock::gatherVariables() const {
88

99
for (const auto &instruction: instructions) {
1010
for (const auto &read: instruction->getRead())
11-
if (!read->is<Global>() && read->reg < 0)
11+
if (!read->is<Global>() && read->getReg() < 0)
1212
out.insert(read);
1313
for (const auto &written: instruction->getWritten())
14-
if (!written->is<Global>() && written->reg < 0)
14+
if (!written->is<Global>() && written->getReg() < 0)
1515
out.insert(written);
1616
}
1717

src/ColoringAllocator.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ ColoringAllocator::Result ColoringAllocator::attempt() {
4545

4646
for (const std::pair<const std::string, Node *> &pair: interference) {
4747
VregPtr ptr = pair.second->get<VregPtr>();
48-
if (ptr->reg == -1)
49-
ptr->reg = *pair.second->colors.begin();
48+
if (ptr->getReg() == -1)
49+
ptr->setReg(*pair.second->colors.begin());
5050
}
5151

5252
return Result::Success;
@@ -57,7 +57,7 @@ VregPtr ColoringAllocator::selectMostLive(int *liveness_out) const {
5757
VregPtr ptr;
5858
int highest = -1;
5959
for (const auto &var: function.virtualRegisters) {
60-
if (Why::isSpecialPurpose(var->reg) || !function.canSpill(var))
60+
if (Why::isSpecialPurpose(var->getReg()) || !function.canSpill(var))
6161
continue;
6262

6363
const int sum = function.getLiveIn(var).size() + function.getLiveOut(var).size();
@@ -87,7 +87,7 @@ void ColoringAllocator::makeInterferenceGraph() {
8787
size_t links = 0;
8888

8989
for (const auto &var: function.virtualRegisters) {
90-
if (var->reg == -1) {
90+
if (var->getReg() == -1) {
9191
const std::string id = std::to_string(var->id);
9292
if (!interference.hasLabel(id)) {
9393
Node &node = interference.addNode(id);
@@ -107,7 +107,7 @@ void ColoringAllocator::makeInterferenceGraph() {
107107

108108
for (const auto &var: function.virtualRegisters) {
109109
const int id = var->id;
110-
if (var->reg != -1)
110+
if (var->getReg() != -1)
111111
continue;
112112
for (const std::weak_ptr<BasicBlock> &bptr: var->writingBlocks) {
113113
const auto index = bptr.lock()->index;
@@ -130,14 +130,14 @@ void ColoringAllocator::makeInterferenceGraph() {
130130
auto &set = sets[block->index];
131131
for (const VregPtr &var: block->liveIn) {
132132
const int id = var->id;
133-
if (var->reg == -1 && set.count(id) == 0) {
133+
if (var->getReg() == -1 && set.count(id) == 0) {
134134
vec.push_back(id);
135135
set.insert(id);
136136
}
137137
}
138138
for (const VregPtr &var: block->liveOut) {
139139
const int id = var->id;
140-
if (var->reg == -1 && set.count(id) == 0) {
140+
if (var->getReg() == -1 && set.count(id) == 0) {
141141
vec.push_back(id);
142142
set.insert(id);
143143
}

src/Expr.cpp

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
#include "Scope.h"
1616
#include "WhyInstructions.h"
1717

18-
using Argument = std::variant<ExprPtr, VregPtr>;
19-
2018
void compileCall(VregPtr destination, Function &function, ScopePtr scope, FunctionPtr fnptr,
21-
std::initializer_list<Argument> arguments, const ASTLocation &location) {
19+
std::initializer_list<Argument> arguments, const ASTLocation &location, size_t multiplier) {
2220
if (arguments.size() != fnptr->argumentCount())
2321
throw GenericError(location, "Invalid number of arguments in call to " + fnptr->name + " at " +
2422
std::string(location) + ": " + std::to_string(arguments.size()) + " (expected " +
@@ -32,10 +30,10 @@ void compileCall(VregPtr destination, Function &function, ScopePtr scope, Functi
3230
size_t i = 0;
3331

3432
for (const auto &argument: arguments) {
35-
auto argument_register = function.precolored(i);
33+
auto argument_register = function.precolored(Why::argumentOffset + i);
3634
auto &fn_arg_type = *fnptr->getArgumentType(i);
37-
if (std::holds_alternative<ExprPtr>(argument)) {
38-
auto expr = std::get<ExprPtr>(argument);
35+
if (std::holds_alternative<Expr *>(argument)) {
36+
auto expr = std::get<Expr *>(argument);
3937
auto argument_type = expr->getType(scope);
4038
if (argument_type->isStruct())
4139
throw GenericError(expr->getLocation(), "Structs cannot be directly passed to functions; "
@@ -67,8 +65,13 @@ void compileCall(VregPtr destination, Function &function, ScopePtr scope, Functi
6765
for (size_t i = arguments.size(); 0 < i; --i)
6866
function.add<StackPopInstruction>(function.precolored(Why::argumentOffset + i - 1))->setDebug(debug);
6967

70-
if (!fnptr->returnType->isVoid() && destination)
71-
function.add<MoveInstruction>(function.precolored(Why::returnValueOffset), destination)->setDebug(debug);
68+
if (!fnptr->returnType->isVoid() && destination) {
69+
auto r0 = function.precolored(Why::returnValueOffset);
70+
if (multiplier == 1)
71+
function.add<MoveInstruction>(r0, destination)->setDebug(debug);
72+
else
73+
function.add<MultIInstruction>(r0, destination, multiplier)->setDebug(debug);
74+
}
7275
}
7376

7477
std::string stringify(const Expr *expr) {
@@ -396,29 +399,33 @@ std::ostream & operator<<(std::ostream &os, const Expr &expr) {
396399
}
397400

398401
void PlusExpr::compile(VregPtr destination, Function &function, ScopePtr scope, ssize_t multiplier) {
399-
VregPtr left_var = function.newVar(), right_var = function.newVar();
400402
auto left_type = left->getType(scope), right_type = right->getType(scope);
401-
402-
if (left_type->isPointer() && right_type->isInt()) {
403-
if (multiplier != 1)
404-
throw GenericError(getLocation(), "Cannot multiply in pointer arithmetic PlusExpr");
405-
auto *left_subtype = dynamic_cast<PointerType &>(*left_type).subtype;
406-
left->compile(left_var, function, scope, 1);
407-
right->compile(right_var, function, scope, left_subtype->getSize());
408-
} else if (left_type->isInt() && right_type->isPointer()) {
409-
if (multiplier != 1)
410-
throw GenericError(getLocation(), "Cannot multiply in pointer arithmetic PlusExpr");
411-
auto *right_subtype = dynamic_cast<PointerType &>(*right_type).subtype;
412-
left->compile(left_var, function, scope, right_subtype->getSize());
413-
right->compile(right_var, function, scope, 1);
414-
} else if (!(*left_type && *right_type)) {
415-
throw ImplicitConversionError(TypePtr(left_type->copy()), TypePtr(right_type->copy()), getLocation());
403+
if (auto fnptr = function.program.getOperator({left_type.get(), right_type.get()}, CMMTOK_PLUS, getLocation())) {
404+
compileCall(destination, function, scope, fnptr, {left.get(), right.get()}, getLocation(), multiplier);
416405
} else {
417-
left->compile(left_var, function, scope);
418-
right->compile(right_var, function, scope);
419-
}
406+
VregPtr left_var = function.newVar(), right_var = function.newVar();
407+
408+
if (left_type->isPointer() && right_type->isInt()) {
409+
if (multiplier != 1)
410+
throw GenericError(getLocation(), "Cannot multiply in pointer arithmetic PlusExpr");
411+
auto *left_subtype = dynamic_cast<PointerType &>(*left_type).subtype;
412+
left->compile(left_var, function, scope, 1);
413+
right->compile(right_var, function, scope, left_subtype->getSize());
414+
} else if (left_type->isInt() && right_type->isPointer()) {
415+
if (multiplier != 1)
416+
throw GenericError(getLocation(), "Cannot multiply in pointer arithmetic PlusExpr");
417+
auto *right_subtype = dynamic_cast<PointerType &>(*right_type).subtype;
418+
left->compile(left_var, function, scope, right_subtype->getSize());
419+
right->compile(right_var, function, scope, 1);
420+
} else if (!(*left_type && *right_type)) {
421+
throw ImplicitConversionError(TypePtr(left_type->copy()), TypePtr(right_type->copy()), getLocation());
422+
} else {
423+
left->compile(left_var, function, scope, multiplier);
424+
right->compile(right_var, function, scope, multiplier);
425+
}
420426

421-
function.add<AddRInstruction>(left_var, right_var, destination)->setDebug(*this);
427+
function.add<AddRInstruction>(left_var, right_var, destination)->setDebug(*this);
428+
}
422429
}
423430

424431
std::optional<ssize_t> PlusExpr::evaluate(ScopePtr scope) const {

src/Function.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void Function::setArguments(const std::vector<std::pair<std::string, TypePtr>> &
125125
argument->init();
126126
argumentMap.emplace(argument_name, argument);
127127
if (i < Why::argumentCount) {
128-
argument->reg = Why::argumentOffset + i;
128+
argument->setReg(Why::argumentOffset + i);
129129
} else
130130
throw GenericError(getLocation(), "Functions with greater than " + std::to_string(Why::argumentCount) +
131131
" arguments are currently unsupported.");
@@ -199,7 +199,6 @@ void Function::compile() {
199199

200200
auto rt = precolored(Why::returnAddressOffset);
201201

202-
203202
if (!is_init) {
204203
closeScope();
205204
auto gp_regs = usedGPRegisters();
@@ -235,11 +234,11 @@ std::set<int> Function::usedGPRegisters() const {
235234
std::set<int> out;
236235
for (const auto &instruction: instructions) {
237236
for (const auto &var: instruction->getRead())
238-
if (Why::isGeneralPurpose(var->reg))
239-
out.insert(var->reg);
237+
if (Why::isGeneralPurpose(var->getReg()))
238+
out.insert(var->getReg());
240239
for (const auto &var: instruction->getWritten())
241-
if (Why::isGeneralPurpose(var->reg))
242-
out.insert(var->reg);
240+
if (Why::isGeneralPurpose(var->getReg()))
241+
out.insert(var->getReg());
243242
}
244243
return out;
245244
}
@@ -257,9 +256,9 @@ ScopePtr Function::newScope(int *id_out) {
257256
return new_scope;
258257
}
259258

260-
VregPtr Function::precolored(int reg) {
259+
VregPtr Function::precolored(int reg, bool bypass) {
261260
auto out = std::make_shared<VirtualRegister>(*this)->init();
262-
out->reg = reg;
261+
out->setReg(reg, bypass);
263262
out->precolored = true;
264263
return out;
265264
}
@@ -1238,6 +1237,15 @@ bool Function::isDeclaredOnly() const {
12381237
}
12391238

12401239
bool Function::isMatch(TypePtr return_type, const std::vector<TypePtr> &argument_types, const std::string &struct_name)
1240+
const {
1241+
std::vector<Type *> raw_pointers;
1242+
for (const auto &type: argument_types)
1243+
raw_pointers.push_back(type.get());
1244+
1245+
return isMatch(return_type, raw_pointers, struct_name);
1246+
}
1247+
1248+
bool Function::isMatch(TypePtr return_type, const std::vector<Type *> &argument_types, const std::string &struct_name)
12411249
const {
12421250
if (return_type && *returnType != *return_type)
12431251
return false;

src/Program.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ size_t Program::getStringID(const std::string &str) {
369369
return stringIDs[str] = stringIDs.size();
370370
}
371371

372-
FunctionPtr Program::getOperator(const std::vector<TypePtr> &types, int oper, const ASTLocation &location) const {
372+
FunctionPtr Program::getOperator(const std::vector<Type *> &types, int oper, const ASTLocation &location) const {
373373
if (operators.count(oper) == 0)
374374
return nullptr;
375375

src/Variable.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ size_t VirtualRegister::getSize() const {
3131
return type->getSize();
3232
}
3333

34+
VirtualRegister & VirtualRegister::setReg(int new_reg, bool bypass) {
35+
if (!bypass && (new_reg == 0 || new_reg == 1))
36+
throw std::out_of_range("Invalid register: " + std::to_string(new_reg));
37+
reg = new_reg;
38+
return *this;
39+
}
40+
3441
Variable::Variable(const std::string &name_, std::shared_ptr<Type> type_, Function &function_):
3542
VirtualRegister(function_, type_), name(name_) {}
3643

0 commit comments

Comments
 (0)