Skip to content

Commit 01460e0

Browse files
authored
ROCm: use GcnArch for mcpu and ApiVersion to select code object version (#6447)
1 parent 1228111 commit 01460e0

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/target/target_kind.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,19 +172,28 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
172172
arch = ExtractIntWithPrefix(mcpu, "gfx");
173173
CHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
174174
} else {
175-
TVMRetValue version;
176-
if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &version)) {
177-
LOG(WARNING) << "Unable to detect ROCm version, default to \"-mcpu=gfx305\" instead";
178-
arch = 305;
175+
TVMRetValue val;
176+
if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
177+
LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead";
178+
arch = 900;
179179
} else {
180-
arch = version.operator int();
180+
arch = val.operator int();
181181
}
182182
attrs.Set("mcpu", String("gfx") + std::to_string(arch));
183183
}
184184
// Update -mattr before ROCm 3.5:
185185
// Before ROCm 3.5 we needed code object v2, starting
186186
// with 3.5 we need v3 (this argument disables v3)
187-
if (arch < 305) {
187+
188+
TVMRetValue val;
189+
int version;
190+
if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) {
191+
LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5";
192+
version = 305;
193+
} else {
194+
version = val.operator int();
195+
}
196+
if (version < 305) {
188197
Array<String> mattr;
189198
if (attrs.count("mattr")) {
190199
mattr = Downcast<Array<String>>(attrs.at("mattr"));

0 commit comments

Comments
 (0)