@@ -172,19 +172,28 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
172172 arch = ExtractIntWithPrefix (mcpu, " gfx" );
173173 CHECK (arch != -1 ) << " ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
174174 } else {
175- TVMRetValue version ;
176- if (!DetectDeviceFlag ({kDLROCM , 0 }, runtime::kApiVersion , &version )) {
177- LOG (WARNING) << " Unable to detect ROCm version , default to \" -mcpu=gfx305 \" instead" ;
178- arch = 305 ;
175+ TVMRetValue val ;
176+ if (!DetectDeviceFlag ({kDLROCM , 0 }, runtime::kGcnArch , &val )) {
177+ LOG (WARNING) << " Unable to detect ROCm compute arch , default to \" -mcpu=gfx900 \" instead" ;
178+ arch = 900 ;
179179 } else {
180- arch = version .operator int ();
180+ arch = val .operator int ();
181181 }
182182 attrs.Set (" mcpu" , String (" gfx" ) + std::to_string (arch));
183183 }
184184 // Update -mattr before ROCm 3.5:
185185 // Before ROCm 3.5 we needed code object v2, starting
186186 // with 3.5 we need v3 (this argument disables v3)
187- if (arch < 305 ) {
187+
188+ TVMRetValue val;
189+ int version;
190+ if (!DetectDeviceFlag ({kDLROCM , 0 }, runtime::kApiVersion , &val)) {
191+ LOG (WARNING) << " Unable to detect ROCm version, assuming >= 3.5" ;
192+ version = 305 ;
193+ } else {
194+ version = val.operator int ();
195+ }
196+ if (version < 305 ) {
188197 Array<String> mattr;
189198 if (attrs.count (" mattr" )) {
190199 mattr = Downcast<Array<String>>(attrs.at (" mattr" ));
0 commit comments